論文研讀:SA-GAN

前言

SA-GAN( Self-Attention GAN)[1] 由 Google Brain出品,效果比之前的 SOTA 提高的有點多,并且還有開源代碼學習下:https://github.com/brain-research/self-attention-gan

PS. Attention Mechanism 相關(guān)的論文還沒怎么看,所以注意力的模塊原理等先不展開討論,主要討論SA-GAN中的自注意力模塊的設(shè)計。

問題以及 idea

SA-GAN 先從之前的SOTA的圖像生成模型[2] 出發(fā),發(fā)現(xiàn)之前的模型在生成紋理性強的類的圖片(如海洋、天空等)時效果好,但是生成結(jié)構(gòu)性強的圖片效果不好。

SA-GAN 針對的是圖像生成中卷積層的缺點:傳統(tǒng)的卷積層由于本身局部性的特點只能處理一定大小的局部相鄰的信息,如果僅使用卷積層對圖像進行建模,對圖像中長距離依賴的特征的建模效率不高(要使用多層卷積層才能處理到更大范圍、更高層的特征)。基于這個問題,SA-GAN 提出了自注意力層。

Self-Attention Module

自注意力模塊架構(gòu)如圖:

Self-Attention Module

對自注意力模塊的輸入圖片{x \in \mathbb{R}^{C \times N}}

此時 {x} 是將寬高兩個維度flatten成一個維度了({N = W \times H}
),將某一個位置(i,j)的通道上的向量看成該位置的擁有的feature,也就是{x_{i} \in \mathbb{R}^{C \times 1}},接著做如下的操作:

  • {x}通過函數(shù){f,g}。

  • 計算{s},其中{s_{ij} = f(x_i)^T g(x_j) }。

  • {s}按行做softmax歸一化得到{\beta},此時{\beta}的元素{\beta_{ji}}表示了第{i}位置的像素的feature 對第j位置的輸出的貢獻(權(quán)重,注意力)

  • {x}通過函數(shù){h}變換,再使用計算的注意力{\beta}{h(x)}相乘(也就是一個加權(quán)的過程),這樣輸出的每個像素都由原來全部像素點的特征組合而來。

  • 通過函數(shù){v}得到輸出。

  • 最后將注意力層的輸出乘上一個因子{\gamma}和輸入累加輸出。

{ y = \gamma o + x }

實現(xiàn)中:

  • 其中,四個函數(shù)的權(quán)重矩陣的大小為:
    {W_f \in \mathbb{R}^{\tilde{C} \times C }},{W_g \in \mathbb{R}^{\tilde{C} \times C }},{W_h \in \mathbb{R}^{\tilde{C} \times C }},{W_f \in \mathbb{R}^{C \times \tilde{C} }}

  • 實現(xiàn)中所有函數(shù)都由一個1x1的卷積層代替,{\tilde{C} = C / k},并且文中提到將{\tilde{C}}調(diào)整到{C/8} 時模型表現(xiàn)沒有下降,為了效率起見使用{k = 8}。

  • 其中自注意力層的輸出比例{\gamma}為可學習的參數(shù),初始化為0,由模型學習自注意力層應(yīng)該在原圖上修改多少細節(jié)。

設(shè)計思路闡述:

  • 使用1x1卷積層和{\tilde{C} = C / 8} 的設(shè)計應(yīng)該是減少計算量,{f,g,h}{v}就是分別起降維和升維的作用。

  • 輸入圖片{x}的每個位置在通道維度的向量作為該位置的特征向量{x_i}。

簡單的pytorch 實現(xiàn):

import torch.nn as nn 
import torch 
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

def sn_conv1x1(in_dim, out_dim):
    return spectral_norm(nn.Conv2d(in_dim, out_dim, 1,  1))

class SAModule(nn.Module):
    def __init__(self, dim, k = 8):
        super(SAModule, self).__init__()
        self.c = dim // k
        self.f = sn_conv1x1(dim, self.c)
        self.g = sn_conv1x1(dim, self.c)
        self.h = sn_conv1x1(dim, self.c)
        self.v = sn_conv1x1(self.c, dim)
        self.gamma = nn.Parameter(torch.tensor(0, dtype=torch.float),  requires_grad=True)
    
    def forward(self, x):
        N, C, H, W = x.size()
        num_locals = H * W
        theta = self.f(x).view(N, self.c, num_locals).permute(0,2,1)
        phi = self.g(x).view(N, self.c, num_locals)

        s = torch.bmm(theta, phi)
        atten = F.softmax(s, dim = 2)

        a = self.h(x).view(N, self.c, num_locals).permute(0,2,1)
        out = torch.bmm(atten, a).permute(0,2,1).view(N, self.c, H, W).contiguous()
        x0 = self.v(out)

        return x0 * self.gamma + x

使用的 trick

  • Spectral Normalization[3]: 譜歸一化在[2]中就已經(jīng)使用,而且[2][3]都是同一個作者提出,這里簡述一下:譜歸一化從WGAN出發(fā),是一種限制模型權(quán)重Lipschitz常數(shù)的方法,好處是計算代價低(僅需要一次迭代),而且不需要超參數(shù)。[2] 中的模型僅在判別器上使用,SA-GAN在生成器和判別器上都使用 Spectral Normalization??梢詤⒖枷?a target="_blank">這個簡單的實驗體會Spectral Normalization 的作用。

  • 判別器和生成器使用不同的學習率,判別器0.0004,生成器0.0001。

  • 在判別器上使用了[2] 中的 projection discriminator 結(jié)構(gòu),在生成器上使用了 Conditional BatchNorm 的結(jié)構(gòu)。

實驗、復(fù)現(xiàn)細節(jié)、超參數(shù)、實驗結(jié)論

  • 使用 Hinge Adv Loss。

  • 文章通過實驗證明:對 middel-to-high level 的 feature 使用 Self-Attention Layer 比對 low level 的 feature 上使用效果更好(即自注意力層盡量放置在靠近輸出層的卷積層之后)因為這樣Self-Attention 可以捕捉更高級特征和更大區(qū)域的feature 的聯(lián)系。

  • 文中又用殘差塊替換原有卷積層訓練新模型,結(jié)果效果沒有baseline好,證明自注意力層的加入起作用不是因為加深了模型、擴大了模型的容量。

Reference

  1. Self-Attention Generative Adversarial Networks
  2. cGAN with projection discriminator
  3. Spectral Normalization for Generative Adversarial Networks

后言

文章是為了方便自己理解而寫,所以難免有不清楚或錯誤之處、或者自創(chuàng)的方便理解的術(shù)語,如有錯誤,歡迎指正。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容