神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索——可微分搜索(PC-DARTS)

PC-DARTS原理及源碼解析

華為發(fā)表在ICLR 2020上的NAS工作,針對(duì)現(xiàn)有DARTS模型訓(xùn)練時(shí)需要 Large memory and computing 問(wèn)題,提出了 Partial Channel ConnectionEdge Normalization 的技術(shù),在搜索過(guò)程中更快更好

動(dòng)機(jī)

接著上面的P-DARTS來(lái)看,盡管上面可以在17 cells情況下單卡完成搜索,但妥協(xié)犧牲的是operation的數(shù)量,這明顯不是個(gè)優(yōu)秀的方案,故此文 Partially-Connected DARTS,致力于大規(guī)模節(jié)省計(jì)算量和memory,從而進(jìn)行快速且大batchsize的搜索。

貢獻(xiàn)點(diǎn)

  • 設(shè)計(jì)了基于channel的sampling機(jī)制,故每次只有小部分1/K channel的node來(lái)進(jìn)行operation search,減少了(K-1)/K 的memory,故batchsize可增大為K倍。

  • 為了解決上述channel采樣導(dǎo)致的不穩(wěn)定性,提出了 邊緣正規(guī)化(edge normalization),在搜索時(shí)通過(guò)學(xué)習(xí)edge-level超參來(lái)減少不確定性。

方法

PC-DARTS架構(gòu)

部分通道連接(Partial Channel Connection)

如上圖的上半部分,在所有的通道數(shù)K里隨機(jī)采樣 1/K 出來(lái),進(jìn)行 operation search,然后operation 混合后的結(jié)果與剩下的 (K-1)/K 通道數(shù)進(jìn)行 concat,公式表示如下:

f_{i, j}^{\mathrm{PC}}\left(\mathbf{x}_{i} ; \mathbf{S}_{i, j}\right)=\sum_{o \in \mathcal{O}} \frac{\exp \left\{\alpha_{i, j}^{o}\right\}}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left\{\alpha_{i, j}^{o^{\prime}}\right\}} \cdot o\left(\mathbf{S}_{i, j} * \mathbf{x}_{i}\right)+\left(1-\mathbf{S}_{i, j}\right) * \mathbf{x}_{i}

上述的“部分通道連接”操作會(huì)帶來(lái)一些正副作用:

  • 正作用:能減少operations選擇時(shí)的biases,弱化無(wú)參的子操作(Pooling, Skip-Connect)的作用。文中3.3節(jié)有這么一句話:當(dāng)proxy dataset非常難時(shí)(即ImageNet),往往一開(kāi)始都會(huì)累積很大權(quán)重在weight-free operation,故制約了其在ImageNet上直接搜索的性能。
  • 副作用:由于網(wǎng)絡(luò)架構(gòu)在不同iterations優(yōu)化是基于隨機(jī)采樣的channels,故最優(yōu)的edge連通性將會(huì)不穩(wěn)定。
class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.mp = nn.MaxPool2d(2,2)

    for primitive in PRIMITIVES:
      op = OPS[primitive](C //4, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C //4, affine=False))
      self._ops.append(op)

  def forward(self, x, weights):
    #channel proportion k=4(實(shí)驗(yàn)證明1/4性能最佳)
    dim_2 = x.shape[1]
    xtemp = x[ : , :  dim_2//4, :, :] # channel 0到1/4的輸入
    xtemp2 = x[ : ,  dim_2//4:, :, :] # channel 1/4到1的輸入
    temp1 = sum(w * op(xtemp) for w, op in zip(weights, self._ops)) # 僅1/4數(shù)據(jù)參與ops運(yùn)算
    #reduction cell 需要在concat之前添加pooling操作
    if temp1.shape[2] == x.shape[2]:
      ans = torch.cat([temp1,xtemp2],dim=1)
    else:
      ans = torch.cat([temp1,self.mp(xtemp2)], dim=1)
    ans = channel_shuffle(ans,4) # 一個(gè)cell完成后對(duì)channel進(jìn)行隨機(jī)打散,為下個(gè)cell做采樣準(zhǔn)備
    #ans = torch.cat([ans[ : ,  dim_2//4:, :, :],ans[ : , :  dim_2//4, :, :]],dim=1)
    #except channe shuffle, channel shift also works
    return ans

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape [batchsize, num_channels, height, width] 
    # -> [batchsize, groups,channels_per_group, height, width]
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)
        # 打亂channel的操作(借助transpose后數(shù)據(jù)塊的stride發(fā)生變化,然后將其連續(xù)化)
    # 參考:https://www.cnblogs.com/aoru45/p/10974508.html
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

邊緣正規(guī)化(Edge Normalization)

為了克服部分通道連接這個(gè)副作用,提出邊緣正規(guī)化(見(jiàn)上圖的下半部分),即把多個(gè)PC后的node輸入softmax權(quán)值疊加,類attention機(jī)制

\mathbf{x}_{j}^{\mathrm{PC}}=\sum_{o \in \mathcal{O}} \frac{\exp \left\{\alpha_{i, j}^{o}\right\}}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left\{\alpha_{i, j}^{o^{\prime}}\right\}} \cdot o\left(\mathbf{S}_{i, j} * \mathbf{x}_{i}\right)+\left(1-\mathbf{S}_{i, j}\right) * \mathbf{x}_{i}

\mathbf{x}_{j}^{\mathrm{PC}}=\sum_{i < j} \frac{\exp \left\{\beta_{i, j}\right\}}{\sum_{i^{\prime} < j} \exp \left\{\beta_{i^{\prime}, j}\right\}} \cdot f_{i, j}\left(\mathbf{x}_{i}\right)

由于edge 超參 \beta_{i, j} 在訓(xùn)練階段是共享的,故學(xué)習(xí)到的網(wǎng)絡(luò)更少依賴于不同iterations間的采樣到的channels,使得網(wǎng)絡(luò)搜索過(guò)程更穩(wěn)定。當(dāng)網(wǎng)絡(luò)搜索完畢,node間的operation選擇由operation-level和edge-level的參數(shù)相乘后共同決定。

weights_normal = [F.softmax(alpha, dim=-1) for alpha in alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in alpha_reduce]
weights_edge_normal = [F.softmax(beta, dim=0) for beta in beta_normal]
weights_edge_reduce = [F.softmax(beta, dim=0) for beta in beta_reduce]


def parse(alpha, beta, k):
        ...  
    for edges, w in zip(alpha, beta):
        edge_max, primitive_indices = torch.topk((w.view(-1, 1) * edges)[:, :-1], 1) # ignore 'none'
    ...

實(shí)驗(yàn)結(jié)果

CIFAR-10

CIFAR-10結(jié)果

ImageNet

ImageNet結(jié)果

消融實(shí)驗(yàn)

消融實(shí)驗(yàn)

參考

[1] Yuhui Xu et al. ,PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search

[2] https://zhuanlan.zhihu.com/p/73740783

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容