小米實(shí)驗(yàn)室 AutoML 團(tuán)隊(duì)的NAS工作,論文題目:Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search。 針對(duì)現(xiàn)有DARTS框架在搜索階段訓(xùn)練過(guò)程中存在 skip-connection 富集現(xiàn)象,導(dǎo)致最終模型出現(xiàn)大幅度的性能損失問(wèn)題的問(wèn)題,提出了Sigmoid替代Softmax的方法,使搜索階段候選操作由競(jìng)爭(zhēng)關(guān)系轉(zhuǎn)化為合作關(guān)系。 并提出 0-1 loss 提高了架構(gòu)參數(shù)的二值性。
動(dòng)機(jī)
skip-connection 富集現(xiàn)象

本文指出 skip connections 富集的原因主要有兩個(gè)方面:
skip connections 的不公平優(yōu)勢(shì)
在超網(wǎng)絡(luò)訓(xùn)練架構(gòu)參數(shù)過(guò)程中,兩個(gè)節(jié)點(diǎn)之間是八個(gè)操作同時(shí)作用的, skip connections 作為操作的其中一員,相較于其他的操作來(lái)講是起到了跳躍連接的作用。在ResNet 中已經(jīng)明確指出了跳躍連接在深層網(wǎng)絡(luò)的訓(xùn)練過(guò)程中中起到了良好的梯度疏通效果,進(jìn)而有效減緩了梯度消失現(xiàn)象。因此,在超網(wǎng)絡(luò)的搜索訓(xùn)練過(guò)程中,skip connections可以借助其他操作的關(guān)系達(dá)到疏通效果,使得,skip connections 相較于其他操作存在不公平優(yōu)勢(shì)。
softmax 的排外競(jìng)爭(zhēng)
由于 softmax 是典型的歸一化操作,是一種潛在的排外競(jìng)爭(zhēng)方式,致使一個(gè)架構(gòu)參數(shù)增大必然抑制其他參數(shù)。
部署訓(xùn)練的離散化差異(discretization discrepancy)
搜索過(guò)程結(jié)束后,在部署訓(xùn)練選取網(wǎng)絡(luò)架構(gòu)時(shí),直接將 softmax 后最大 α 值對(duì)應(yīng)的操作保留而拋棄其它的操作,從而使得選出的網(wǎng)絡(luò)結(jié)構(gòu)和原始包含所有結(jié)構(gòu)的超網(wǎng)二者的表現(xiàn)能力存在差距。離散化差異問(wèn)題主要在于兩點(diǎn),一方面Softmax歸一化八種操作參數(shù)后,DARTS 最后選擇時(shí)的 α 值基本都在 0.1 到 0.3 之間,另一方面判定好壞的范圍比較窄,因?yàn)椴煌僮?α 值的 top1 和 top2 可能差距特別小,例如 0.26 和 0.24,很難說(shuō) 0.26 就一定比 0.24 好,如下圖所示:

方法
sigmoid 函數(shù)替換 softmax
class Network(nn.Module):
def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3,parse_method='darts', op_threshold=None):
pass
def forward(self, input):
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = F.sigmoid(self.alphas_reduce) # sigmoid 替換softmax
else:
weights = F.sigmoid(self.alphas_normal) # sigmoid 替換softmax
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
0-1 損失函數(shù)
l2 0-1 損失函數(shù)
l1 0-1 損失函數(shù)
# l2
class ConvSeparateLoss(nn.modules.loss._Loss):
"""Separate the weight value between each operations using L2"""
def __init__(self, weight=0.1, size_average=None, ignore_index=-100,reduce=None, reduction='mean'):
super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
self.ignore_index = ignore_index
self.weight = weight
def forward(self, input1, target1, input2):
loss1 = F.cross_entropy(input1, target1)
loss2 = -F.mse_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
return loss1 + self.weight*loss2, loss1.item(), loss2.item()
# l1
class TriSeparateLoss(nn.modules.loss._Loss):
"""Separate the weight value between each operations using L1"""
def __init__(self, weight=0.1, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
super(TriSeparateLoss, self).__init__(size_average, reduce, reduction)
self.ignore_index = ignore_index
self.weight = weight
def forward(self, input1, target1, input2):
loss1 = F.cross_entropy(input1, target1)
loss2 = -F.l1_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
return loss1 + self.weight*loss2, loss1.item(), loss2.item()
使用上述損失函數(shù)就可以使得不同操作之間的差距增大,二者的 α 值要么逼近 0 要么逼近 1 如下圖曲線所示

實(shí)驗(yàn)
CIFAR-10
精度比較
FairDARTS 搜索 7 次均可得到魯棒性的結(jié)果:

skip connections 數(shù)量比較
DARTS 和 Fair DARTS 搜索出來(lái)的 cell 中所包含的 skip connections 數(shù)量比較:

ImageNet
精度比較
注意模型 A、B 是遷移比較,C、D 是直接搜索比較。

sigmoid 函數(shù)的共存性
熱力圖可看出使用 sigmoid 函數(shù)可讓其他操作和 skip connections 共存:


消融實(shí)驗(yàn)
去掉 Skip Connections
由于不公平的優(yōu)勢(shì)主要來(lái)自 Skip Connections,因此,搜索空間去掉 Skip Connections,那么即使在排他性競(jìng)爭(zhēng)中,其他操作也應(yīng)該期待公平競(jìng)爭(zhēng)。 去掉 Skip Connections搜索得到的最佳模型(96.88±0.18%)略高于DARTS(96.76±0.32%),但低于FairDARTS(97.41±0.14%)。 降低的精度表明足夠的 Skip Connections 確實(shí)對(duì)精度有益,因此也不能簡(jiǎn)單去掉。

0-1 損失函數(shù)分析

- 如果去掉 0-1 損失函數(shù)會(huì)使得 α 值不再集中于兩端,不利于離散化;
- 損失靈敏度,即通過(guò)超參來(lái)控制
損失函數(shù)的靈敏度
討論
對(duì)于 skip connections 使用 dropout 可以減少了不公平性;
對(duì)所有操作使用 dropout 同樣是有幫助的;
早停機(jī)制同樣關(guān)鍵(相當(dāng)于是在不公平出現(xiàn)以前及時(shí)止損);
限制 skip connections 的數(shù)量需要極大的人為先驗(yàn),因?yàn)橹灰薅?skip connections 的數(shù)量為 2,隨機(jī)搜索也能獲得不錯(cuò)的結(jié)果;
高斯噪聲或許也能打破不公平優(yōu)勢(shì)(孕育出了后面的NoisyDARTS~)。
參考
[1] Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search
[2] DARTS+: Improved Differentiable Architecture Search with Early Stopping
[3] Noisy Differentiable Architecture Search
[4] Fair DARTS:公平的可微分神經(jīng)網(wǎng)絡(luò)搜索
[5] Fair darts代碼解析