PC-DARTS原理及源碼解析
華為發(fā)表在ICLR 2020上的NAS工作,針對(duì)現(xiàn)有DARTS模型訓(xùn)練時(shí)需要 Large memory and computing 問(wèn)題,提出了 Partial Channel Connection 和 Edge Normalization 的技術(shù),在搜索過(guò)程中更快更好。
- aper: PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search
- Code: https://github.com/yuhuixu1993/PC-DARTS
動(dòng)機(jī)
接著上面的P-DARTS來(lái)看,盡管上面可以在17 cells情況下單卡完成搜索,但妥協(xié)犧牲的是operation的數(shù)量,這明顯不是個(gè)優(yōu)秀的方案,故此文 Partially-Connected DARTS,致力于大規(guī)模節(jié)省計(jì)算量和memory,從而進(jìn)行快速且大batchsize的搜索。
貢獻(xiàn)點(diǎn)
設(shè)計(jì)了基于channel的sampling機(jī)制,故每次只有小部分1/K channel的node來(lái)進(jìn)行operation search,減少了(K-1)/K 的memory,故batchsize可增大為K倍。
為了解決上述channel采樣導(dǎo)致的不穩(wěn)定性,提出了 邊緣正規(guī)化(edge normalization),在搜索時(shí)通過(guò)學(xué)習(xí)edge-level超參來(lái)減少不確定性。
方法

部分通道連接(Partial Channel Connection)
如上圖的上半部分,在所有的通道數(shù)K里隨機(jī)采樣 1/K 出來(lái),進(jìn)行 operation search,然后operation 混合后的結(jié)果與剩下的 (K-1)/K 通道數(shù)進(jìn)行 concat,公式表示如下:
上述的“部分通道連接”操作會(huì)帶來(lái)一些正副作用:
- 正作用:能減少operations選擇時(shí)的biases,弱化無(wú)參的子操作(Pooling, Skip-Connect)的作用。文中3.3節(jié)有這么一句話:當(dāng)proxy dataset非常難時(shí)(即ImageNet),往往一開(kāi)始都會(huì)累積很大權(quán)重在weight-free operation,故制約了其在ImageNet上直接搜索的性能。
- 副作用:由于網(wǎng)絡(luò)架構(gòu)在不同iterations優(yōu)化是基于隨機(jī)采樣的channels,故最優(yōu)的edge連通性將會(huì)不穩(wěn)定。
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
self.mp = nn.MaxPool2d(2,2)
for primitive in PRIMITIVES:
op = OPS[primitive](C //4, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C //4, affine=False))
self._ops.append(op)
def forward(self, x, weights):
#channel proportion k=4(實(shí)驗(yàn)證明1/4性能最佳)
dim_2 = x.shape[1]
xtemp = x[ : , : dim_2//4, :, :] # channel 0到1/4的輸入
xtemp2 = x[ : , dim_2//4:, :, :] # channel 1/4到1的輸入
temp1 = sum(w * op(xtemp) for w, op in zip(weights, self._ops)) # 僅1/4數(shù)據(jù)參與ops運(yùn)算
#reduction cell 需要在concat之前添加pooling操作
if temp1.shape[2] == x.shape[2]:
ans = torch.cat([temp1,xtemp2],dim=1)
else:
ans = torch.cat([temp1,self.mp(xtemp2)], dim=1)
ans = channel_shuffle(ans,4) # 一個(gè)cell完成后對(duì)channel進(jìn)行隨機(jī)打散,為下個(gè)cell做采樣準(zhǔn)備
#ans = torch.cat([ans[ : , dim_2//4:, :, :],ans[ : , : dim_2//4, :, :]],dim=1)
#except channe shuffle, channel shift also works
return ans
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape [batchsize, num_channels, height, width]
# -> [batchsize, groups,channels_per_group, height, width]
x = x.view(batchsize, groups,
channels_per_group, height, width)
# 打亂channel的操作(借助transpose后數(shù)據(jù)塊的stride發(fā)生變化,然后將其連續(xù)化)
# 參考:https://www.cnblogs.com/aoru45/p/10974508.html
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
邊緣正規(guī)化(Edge Normalization)
為了克服部分通道連接這個(gè)副作用,提出邊緣正規(guī)化(見(jiàn)上圖的下半部分),即把多個(gè)PC后的node輸入softmax權(quán)值疊加,類attention機(jī)制
由于edge 超參 在訓(xùn)練階段是共享的,故學(xué)習(xí)到的網(wǎng)絡(luò)更少依賴于不同iterations間的采樣到的channels,使得網(wǎng)絡(luò)搜索過(guò)程更穩(wěn)定。當(dāng)網(wǎng)絡(luò)搜索完畢,node間的operation選擇由operation-level和edge-level的參數(shù)相乘后共同決定。
weights_normal = [F.softmax(alpha, dim=-1) for alpha in alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in alpha_reduce]
weights_edge_normal = [F.softmax(beta, dim=0) for beta in beta_normal]
weights_edge_reduce = [F.softmax(beta, dim=0) for beta in beta_reduce]
def parse(alpha, beta, k):
...
for edges, w in zip(alpha, beta):
edge_max, primitive_indices = torch.topk((w.view(-1, 1) * edges)[:, :-1], 1) # ignore 'none'
...
實(shí)驗(yàn)結(jié)果
CIFAR-10

ImageNet

消融實(shí)驗(yàn)

參考
[1] Yuhui Xu et al. ,PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search