神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索——可微分搜索(Fair-DARTS)

小米實(shí)驗(yàn)室 AutoML 團(tuán)隊(duì)的NAS工作,論文題目:Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search。 針對(duì)現(xiàn)有DARTS框架在搜索階段訓(xùn)練過(guò)程中存在 skip-connection 富集現(xiàn)象,導(dǎo)致最終模型出現(xiàn)大幅度的性能損失問(wèn)題的問(wèn)題,提出了Sigmoid替代Softmax的方法,使搜索階段候選操作由競(jìng)爭(zhēng)關(guān)系轉(zhuǎn)化為合作關(guān)系。 并提出 0-1 loss 提高了架構(gòu)參數(shù)的二值性。

動(dòng)機(jī)

skip-connection 富集現(xiàn)象

skip-connection 富集現(xiàn)象

本文指出 skip connections 富集的原因主要有兩個(gè)方面:

skip connections 的不公平優(yōu)勢(shì)

在超網(wǎng)絡(luò)訓(xùn)練架構(gòu)參數(shù)過(guò)程中,兩個(gè)節(jié)點(diǎn)之間是八個(gè)操作同時(shí)作用的, skip connections 作為操作的其中一員,相較于其他的操作來(lái)講是起到了跳躍連接的作用。在ResNet 中已經(jīng)明確指出了跳躍連接在深層網(wǎng)絡(luò)的訓(xùn)練過(guò)程中中起到了良好的梯度疏通效果,進(jìn)而有效減緩了梯度消失現(xiàn)象。因此,在超網(wǎng)絡(luò)的搜索訓(xùn)練過(guò)程中,skip connections可以借助其他操作的關(guān)系達(dá)到疏通效果,使得,skip connections 相較于其他操作存在不公平優(yōu)勢(shì)。

softmax 的排外競(jìng)爭(zhēng)

由于 softmax 是典型的歸一化操作,是一種潛在的排外競(jìng)爭(zhēng)方式,致使一個(gè)架構(gòu)參數(shù)增大必然抑制其他參數(shù)。

部署訓(xùn)練的離散化差異(discretization discrepancy)

搜索過(guò)程結(jié)束后,在部署訓(xùn)練選取網(wǎng)絡(luò)架構(gòu)時(shí),直接將 softmax 后最大 α 值對(duì)應(yīng)的操作保留而拋棄其它的操作,從而使得選出的網(wǎng)絡(luò)結(jié)構(gòu)和原始包含所有結(jié)構(gòu)的超網(wǎng)二者的表現(xiàn)能力存在差距。離散化差異問(wèn)題主要在于兩點(diǎn),一方面Softmax歸一化八種操作參數(shù)后,DARTS 最后選擇時(shí)的 α 值基本都在 0.1 到 0.3 之間,另一方面判定好壞的范圍比較窄,因?yàn)椴煌僮?α 值的 top1 和 top2 可能差距特別小,例如 0.26 和 0.24,很難說(shuō) 0.26 就一定比 0.24 好,如下圖所示:

image

方法

sigmoid 函數(shù)替換 softmax


class Network(nn.Module):

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3,parse_method='darts', op_threshold=None):
        pass

    def forward(self, input):
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = F.sigmoid(self.alphas_reduce) # sigmoid 替換softmax
            else:
                weights = F.sigmoid(self.alphas_normal) # sigmoid 替換softmax
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0),-1))
        return logits

0-1 損失函數(shù)

l2 0-1 損失函數(shù)

L_{0-1}=-\frac{1}{N} \sum_{i}^{N}\left(\sigma\left(\alpha_{i}\right)-0.5\right)^{2}

l1 0-1 損失函數(shù)

L_{0-1}^{\prime}=-\frac{1}{N} \sum_{i}^{N}\left|\left(\sigma\left(\alpha_{i}\right)-0.5\right)\right|

L_{\text {total}}=\mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right)+w_{0-1} L_{0-1}


# l2
class ConvSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L2"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100,reduce=None, reduction='mean'):
        super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight

    def forward(self, input1, target1, input2):
        loss1 = F.cross_entropy(input1, target1)
        loss2 = -F.mse_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
        return loss1 + self.weight*loss2, loss1.item(), loss2.item()

# l1
class TriSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L1"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        super(TriSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight

    def forward(self, input1, target1, input2):
        loss1 = F.cross_entropy(input1, target1)
        loss2 = -F.l1_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
        return loss1 + self.weight*loss2, loss1.item(), loss2.item()

使用上述損失函數(shù)就可以使得不同操作之間的差距增大,二者的 α 值要么逼近 0 要么逼近 1 如下圖曲線所示

導(dǎo)數(shù)可視化圖

實(shí)驗(yàn)

CIFAR-10

精度比較

FairDARTS 搜索 7 次均可得到魯棒性的結(jié)果:

CIFAR-10結(jié)果
skip connections 數(shù)量比較

DARTS 和 Fair DARTS 搜索出來(lái)的 cell 中所包含的 skip connections 數(shù)量比較:

skip connections 數(shù)量比較

ImageNet

精度比較

注意模型 A、B 是遷移比較,C、D 是直接搜索比較。


ImageNet結(jié)果
sigmoid 函數(shù)的共存性

熱力圖可看出使用 sigmoid 函數(shù)可讓其他操作和 skip connections 共存:

image
image

消融實(shí)驗(yàn)

去掉 Skip Connections

由于不公平的優(yōu)勢(shì)主要來(lái)自 Skip Connections,因此,搜索空間去掉 Skip Connections,那么即使在排他性競(jìng)爭(zhēng)中,其他操作也應(yīng)該期待公平競(jìng)爭(zhēng)。 去掉 Skip Connections搜索得到的最佳模型(96.88±0.18%)略高于DARTS(96.76±0.32%),但低于FairDARTS(97.41±0.14%)。 降低的精度表明足夠的 Skip Connections 確實(shí)對(duì)精度有益,因此也不能簡(jiǎn)單去掉。

image
0-1 損失函數(shù)分析
0-1 損失函數(shù)消融實(shí)驗(yàn)
  • 如果去掉 0-1 損失函數(shù)會(huì)使得 α 值不再集中于兩端,不利于離散化;
  • 損失靈敏度,即通過(guò)超參來(lái)控制 w_{0-1} 損失函數(shù)的靈敏度

討論

  • 對(duì)于 skip connections 使用 dropout 可以減少了不公平性;

  • 對(duì)所有操作使用 dropout 同樣是有幫助的;

  • 早停機(jī)制同樣關(guān)鍵(相當(dāng)于是在不公平出現(xiàn)以前及時(shí)止損);

  • 限制 skip connections 的數(shù)量需要極大的人為先驗(yàn),因?yàn)橹灰薅?skip connections 的數(shù)量為 2,隨機(jī)搜索也能獲得不錯(cuò)的結(jié)果;

  • 高斯噪聲或許也能打破不公平優(yōu)勢(shì)(孕育出了后面的NoisyDARTS~)。

參考

[1] Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search
[2] DARTS+: Improved Differentiable Architecture Search with Early Stopping
[3] Noisy Differentiable Architecture Search
[4] Fair DARTS:公平的可微分神經(jīng)網(wǎng)絡(luò)搜索
[5] Fair darts代碼解析

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容