ACGAN-半監(jiān)督式GAN

? ? ? ? CGAN通過(guò)在生成器和判別器中均使用標(biāo)簽信息進(jìn)行訓(xùn)練,不僅能產(chǎn)生特定標(biāo)簽的數(shù)據(jù),還能夠提高生成數(shù)據(jù)的質(zhì)量;SGAN(Semi-Supervised GAN)通過(guò)使判別器/分類器重建標(biāo)簽信息來(lái)提高生成數(shù)據(jù)的質(zhì)量。既然這兩種思路都可以提高生成數(shù)據(jù)的質(zhì)量,于是ACGAN綜合了以上兩種思路,既使用標(biāo)簽信息進(jìn)行訓(xùn)練,同時(shí)也重建標(biāo)簽信息,結(jié)合CGAN和SGAN的優(yōu)點(diǎn),從而進(jìn)一步提升生成樣本的質(zhì)量,并且還能根據(jù)指定的標(biāo)簽相應(yīng)的樣本。

1. ACGAN的網(wǎng)絡(luò)結(jié)構(gòu)為:

ACGAN的網(wǎng)絡(luò)結(jié)構(gòu)框圖

? ? ? ? 生成器輸入包含C_vector和Noise_data兩個(gè)部分,其中C_vector為訓(xùn)練數(shù)據(jù)標(biāo)簽信息的One-hot編碼張量,其形狀為:(batch_size, num_class) ;Noise_data的形狀為:(batch_size, latent_dim)。然后將兩者進(jìn)行拼接,拼接完成后,得到的輸入張量為:(batch_size, num_class + latent_dim)。生成器的的輸出張量為:(batch_size, channel, Height, Width)。

? ? ? ? 判別器輸入為:(batch_size, channel, Height, Width); 判別的器的輸出為兩部分,一部分是源數(shù)據(jù)真假的判斷,形狀為:(batch_size, 1),一部分是輸入數(shù)據(jù)的分類結(jié)果,形狀為:(batch_size, class_num)。因此判別器的最后一層有兩個(gè)并列的全連接層,分別得到這兩部分的輸出結(jié)果,即判別器的輸出有兩個(gè)張量(真假判斷張量和分類結(jié)果張量)。

2. ACGAN的損失函數(shù):

? ? ? ? 對(duì)于判別器而言,既希望分類正確,又希望能正確分辨數(shù)據(jù)的真假;對(duì)于生成器而言,也希望能夠分類正確,當(dāng)時(shí)希望判別器不能正確分辨假數(shù)據(jù)。

判別器的損失函數(shù):??L_{D}=L_{S}+  L_{C}

真假判斷損失:??L_{S} =E[logP(S=real|x_{real} )] + E[logP(S=fake|x_{fake} )]

分類損失:?L_{C} =E[logP(C=c|x_{real} )] + E[logP(C=c|x_{fake} )]

D_real, C_real = Discriminator( real_imgs)? ? ? ? ?# real_img 為輸入的真實(shí)訓(xùn)練圖片

D_real_loss = torch.nn.BCELoss(D_real, Y_real)? ? ? ? ? #? Y_real為真實(shí)數(shù)據(jù)的標(biāo)簽,真數(shù)據(jù)都為-1,假數(shù)據(jù)都為+1

C_real_loss = torch.nn.CrossEntropyLoss(C_real, Y_vec)? ? ? ? # Y_vec為訓(xùn)練數(shù)據(jù)One-hot編碼的標(biāo)簽張量

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_fake)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss

生成器的損失函數(shù):??L_{D}=L_{C}- L_{S}

真假判斷損失:L_{S} = E[logP(S=fake|x_{fake} )]

分類損失:L_{C} =E[logP(C=c|x_{fake} )]

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_real)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

G_loss = D_fake_loss + C_fake_loss

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容