GAN生成對抗網(wǎng)絡(luò)從入門到實踐——入門級

自2014年Ian Goodfellow提出生成對抗網(wǎng)絡(luò)(GAN)的概念后,生成對抗網(wǎng)絡(luò)變成為了學(xué)術(shù)界的一個火熱的研究熱點,Yann LeCun更是稱之為”過去十年間機器學(xué)習(xí)領(lǐng)域最讓人激動的點子”.

生成對抗網(wǎng)絡(luò)包括一個生成器(Generator,簡稱G)生成數(shù)據(jù),一個鑒別器(Discriminator,簡稱D)來鑒別真實數(shù)據(jù)和生成數(shù)據(jù),兩者同時訓(xùn)練,直到達到一個納什均衡,生成器生成的數(shù)據(jù)與真實樣本無差別,鑒別器也無法正確的區(qū)分生成數(shù)據(jù)和真實數(shù)據(jù).

0.入門姿勢——生成模型

生成模型故名思議就是已知模型,用來生成適應(yīng)該模型的數(shù)據(jù),那么它有什么應(yīng)用場景呢?

主要的應(yīng)用場景有兩種:

  1. 當(dāng)我們擁有大量的數(shù)據(jù),例如圖像、語音、文本等,如果生成模型可以幫助我們模擬這些高維數(shù)據(jù)的分布,那么對很多應(yīng)用將大有裨益。

  2. 針對數(shù)據(jù)量缺乏的場景,生成模型則可以幫助生成數(shù)據(jù),提高數(shù)據(jù)數(shù)量,從而利用半監(jiān)督學(xué)習(xí)提升學(xué)習(xí)效率。語言模型(language model)是生成模型被廣泛使用的例子之一,通過合理建模,語言模型不僅可以幫助生成語言通順的句子,還在機器翻譯、聊天對話等研究領(lǐng)域有著廣泛的輔助應(yīng)用。

那么,如果有數(shù)據(jù)集S={x1,…xn},如何建立一個關(guān)于這個類型數(shù)據(jù)的生成模型呢?

最簡單的方法就是:假設(shè)這些數(shù)據(jù)的分布P{X}服從g(x;θ),在觀測數(shù)據(jù)上通過最大化似然函數(shù)得到θ的值,即最大似然法。

例如,我們知道一一些文本中有若干單詞,我們就可以用單詞出現(xiàn)的頻率作為這些數(shù)據(jù)的分布(如單詞“text”的概率0.3,“today”的概率為0.1),以這些概率來生成新的文檔。

GAN也是一種生成模型,不過是一種以于半監(jiān)督學(xué)習(xí)方式訓(xùn)練的模型,基于神經(jīng)網(wǎng)絡(luò),經(jīng)常被用在圖像處理和半監(jiān)督學(xué)習(xí)領(lǐng)域。

1.基本原理

GAN有一個生成器(Generator,簡稱G)生成數(shù)據(jù),一個鑒別器(Discriminator,簡稱D)鑒別數(shù)據(jù)是否與真實數(shù)據(jù)相似。

生成模型

鑒別器的作用G:盡最大努力區(qū)分生成器生成的數(shù)據(jù)和真實數(shù)據(jù)

生成器作用D:生成和真實數(shù)據(jù)幾乎沒有差距的數(shù)據(jù)

上述的博弈過程就基本上是GAN的原理了。

那么GAN的數(shù)學(xué)形式是怎樣的呢?

假設(shè)我們的生成模型是g(z),其中z是一個隨機噪聲,而g將這個隨機噪聲轉(zhuǎn)化為數(shù)據(jù)類型x,仍拿圖片問題舉例,這里g的輸出就是一張圖片。D是一個判別模型,對任何輸入x,D(x)的輸出是0-1范圍內(nèi)的一個實數(shù),用來判斷這個圖片是一個真實圖片的概率是多大。令Pr和Pg分別代表真實圖像的分布與生成圖像的分布.

鑒別器的作用效果:


鑒別器D

生成器的作用效果:


生成器G

整體效果大概是下面這樣:


迭代過程

圖中黑色虛線是真實數(shù)據(jù)的高斯分布,綠色的線是生成網(wǎng)絡(luò)學(xué)習(xí)到的偽造分布,藍色的線是判別網(wǎng)絡(luò)判定為真實圖片的概率,標(biāo)x的橫線代表服從高斯分布x的采樣空間,標(biāo)z的橫線代表服從均勻分布z的采樣空間??梢钥闯鯣就是學(xué)習(xí)了從z的空間到x的空間的映射關(guān)系。

2.GAN優(yōu)缺點及改進

2.1GAN的優(yōu)劣勢

優(yōu)勢

  • GANs是一種以半監(jiān)督方式訓(xùn)練分類器的方法,可以參考我們的NIPS paper相應(yīng)代碼.在你沒有很多帶標(biāo)簽的訓(xùn)練集的時候,你可以不做任何修改的直接使用我們的代碼,通常這是因為你沒有太多標(biāo)記樣本.我最近也成功的使用這份代碼與谷歌大腦部門在深度學(xué)習(xí)的隱私方面合寫了一篇論文
  • GANs可以比完全明顯的信念網(wǎng)絡(luò)(NADE,PixelRNN,WaveNet等)更快的產(chǎn)生樣本,因為它不需要在采樣序列生成不同的數(shù)據(jù).
  • GANs不需要蒙特卡洛估計來訓(xùn)練網(wǎng)絡(luò),人們經(jīng)常抱怨GANs訓(xùn)練不穩(wěn)定,很難訓(xùn)練,但是他們比訓(xùn)練依賴于蒙特卡洛估計和對數(shù)配分函數(shù)的玻爾茲曼機簡單多了.因為蒙特卡洛方法在高維空間中效果不好,玻爾茲曼機從來沒有拓展到像ImgeNet任務(wù)中.GANs起碼在ImageNet上訓(xùn)練后可以學(xué)習(xí)去畫一些以假亂真的狗
  • 相比于變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,因為他們優(yōu)化對數(shù)似然的下界,而不是似然度本身,這看起來導(dǎo)致了VAEs生成的實例比GANs更模糊.
  • 相比非線性ICA(NICE, Real NVE等,),GANs不要求生成器輸入的潛在變量有任何特定的維度或者要求生成器是可逆的.
  • 相比玻爾茲曼機和GSNs,GANs生成實例的過程只需要模型運行一次,而不是以馬爾科夫鏈的形式迭代很多次.

劣勢

  • 訓(xùn)練GAN需要達到納什均衡,有時候可以用梯度下降法做到,有時候做不到.我們還沒有找到很好的達到納什均衡的方法,所以訓(xùn)練GAN相比VAE或者PixelRNN是不穩(wěn)定的,但我認(rèn)為在實踐中它還是比訓(xùn)練玻爾茲曼機穩(wěn)定的多.
  • 它很難去學(xué)習(xí)生成離散的數(shù)據(jù),就像文本
  • 相比玻爾茲曼機,GANs很難根據(jù)一個像素值去猜測另外一個像素值,GANs天生就是做一件事的,那就是一次產(chǎn)生所有像素, 你可以用BiGAN來修正這個特性,它能讓你像使用玻爾茲曼機一樣去使用Gibbs采樣來猜測缺失值,

以上是GAN的發(fā)明者者回答網(wǎng)友問

2.2 GAN的改進

常見的改進深度卷積的對抗生成網(wǎng)絡(luò)(DC-GAN),在圖像中有著很重要的應(yīng)用

在圖像生成過程中,如何設(shè)計生成模型和判別模型呢?深度學(xué)習(xí)里,對圖像分類建模,刻畫圖像不同層次,抽象信息表達的最有效的模型是:CNN (convolutional neural network,卷積神經(jīng)網(wǎng)絡(luò))。

在CSDN上看到一個例子,會發(fā)現(xiàn)DC-GAN的優(yōu)化效果會好很多


DC-GAN優(yōu)化

3.實踐

我們需要準(zhǔn)備以下的東西:

  1. R:原始的真實數(shù)據(jù)
    我們使用一個貝爾分布作為真實的數(shù)據(jù),貝爾分布需要提供一個均值和標(biāo)準(zhǔn)差,代碼中選擇均值和標(biāo)準(zhǔn)差分別為0.4和1.25
# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian
  1. I:一個原始的噪聲,為生成器提供數(shù)據(jù)的來源
    對生成器的輸入也是隨機的,這是要使我們的工作變得困難一點,讓我們使用一個連續(xù)型均勻分布,而不是一個正態(tài)分布。這意味著我們的模型 G 不能簡單地轉(zhuǎn)移/縮放輸入數(shù)據(jù)來達到 R的效果,必須以非線性的方式重新調(diào)整數(shù)據(jù)。
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

  1. G:生成器——努力讓數(shù)據(jù)模仿真實數(shù)據(jù)
# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

  1. D:判別器——努力區(qū)分真實數(shù)據(jù)和生成的數(shù)據(jù)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

  1. 一個用來訓(xùn)練的循環(huán)——讓G和D對抗


    訓(xùn)練網(wǎng)絡(luò)

完整代碼

# Generative Adversarial Networks (GAN) example in PyTorch.
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size

d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

# ### Uncomment only one of these
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian
超過2萬訓(xùn)練回合, 平均 G 的輸出過度 4.0, 但然后回來在一個相當(dāng)穩(wěn)定, 正確的范圍 (左)。同樣, 標(biāo)準(zhǔn)偏差最初下落在錯誤方向, 但然后上升到期望1.25 范圍 (正確), 匹配 R。
# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))
訓(xùn)練結(jié)果

經(jīng)過2萬訓(xùn)練回合, 平均 G 的輸出過度 4.0, 但然后回來在一個相當(dāng)穩(wěn)定, 正確的范圍 (左)。同樣, 標(biāo)準(zhǔn)偏差最初下落在錯誤方向, 但然后上升到期望1.25 范圍 (正確), 匹配 R。


圖片.png

由 G自動生成的最終分配

參考鏈接
到底什么是生成式對抗網(wǎng)絡(luò)
火熱的生成對抗網(wǎng)絡(luò)(GAN),你究竟好在哪里
Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容