? ? ? ? CGAN通過(guò)在生成器和判別器中均使用標(biāo)簽信息進(jìn)行訓(xùn)練,不僅能產(chǎn)生特定標(biāo)簽的數(shù)據(jù),還能夠提高生成數(shù)據(jù)的質(zhì)量;SGAN(Semi-Supervised GAN)通過(guò)使判別器/分類器重建標(biāo)簽信息來(lái)提高生成數(shù)據(jù)的質(zhì)量。既然這兩種思路都可以提高生成數(shù)據(jù)的質(zhì)量,于是ACGAN綜合了以上兩種思路,既使用標(biāo)簽信息進(jìn)行訓(xùn)練,同時(shí)也重建標(biāo)簽信息,結(jié)合CGAN和SGAN的優(yōu)點(diǎn),從而進(jìn)一步提升生成樣本的質(zhì)量,并且還能根據(jù)指定的標(biāo)簽相應(yīng)的樣本。
1. ACGAN的網(wǎng)絡(luò)結(jié)構(gòu)為:

? ? ? ? 生成器的輸入包含C_vector和Noise_data兩個(gè)部分,其中C_vector為訓(xùn)練數(shù)據(jù)標(biāo)簽信息的One-hot編碼張量,其形狀為:(batch_size, num_class) ;Noise_data的形狀為:(batch_size, latent_dim)。然后將兩者進(jìn)行拼接,拼接完成后,得到的輸入張量為:(batch_size, num_class + latent_dim)。生成器的的輸出張量為:(batch_size, channel, Height, Width)。
? ? ? ? 判別器的輸入為:(batch_size, channel, Height, Width); 判別的器的輸出為兩部分,一部分是源數(shù)據(jù)真假的判斷,形狀為:(batch_size, 1),一部分是輸入數(shù)據(jù)的分類結(jié)果,形狀為:(batch_size, class_num)。因此判別器的最后一層有兩個(gè)并列的全連接層,分別得到這兩部分的輸出結(jié)果,即判別器的輸出有兩個(gè)張量(真假判斷張量和分類結(jié)果張量)。
2. ACGAN的損失函數(shù):
? ? ? ? 對(duì)于判別器而言,既希望分類正確,又希望能正確分辨數(shù)據(jù)的真假;對(duì)于生成器而言,也希望能夠分類正確,當(dāng)時(shí)希望判別器不能正確分辨假數(shù)據(jù)。
判別器的損失函數(shù):??
真假判斷損失:??
分類損失:?
D_real, C_real = Discriminator( real_imgs)? ? ? ? ?# real_img 為輸入的真實(shí)訓(xùn)練圖片
D_real_loss = torch.nn.BCELoss(D_real, Y_real)? ? ? ? ? #? Y_real為真實(shí)數(shù)據(jù)的標(biāo)簽,真數(shù)據(jù)都為-1,假數(shù)據(jù)都為+1
C_real_loss = torch.nn.CrossEntropyLoss(C_real, Y_vec)? ? ? ? # Y_vec為訓(xùn)練數(shù)據(jù)One-hot編碼的標(biāo)簽張量
gen_imgs = Generator(noise, Y_vec)
D_fake, C_fake = Discriminator(gen_imgs)
D_fake_loss = torch.nn.BCELoss(D_fake, Y_fake)
C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)
D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
生成器的損失函數(shù):??
真假判斷損失:
分類損失:
gen_imgs = Generator(noise, Y_vec)
D_fake, C_fake = Discriminator(gen_imgs)
D_fake_loss = torch.nn.BCELoss(D_fake, Y_real)
C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)
G_loss = D_fake_loss + C_fake_loss