這是一篇使用增強(qiáng)學(xué)習(xí)來進(jìn)行模型搜索的論文。
結(jié)構(gòu)如下圖:

由于不知道網(wǎng)絡(luò)的長度和結(jié)構(gòu),作者使用了一個(gè)RNN作為控制器,使用該控制器來產(chǎn)生一串信息,用于構(gòu)建網(wǎng)絡(luò)。之后訓(xùn)練該網(wǎng)絡(luò),并用網(wǎng)絡(luò)的accuracy作為reward返回給控制器來更新控制器的參數(shù),達(dá)到更優(yōu)的策略。
其中控制器(RNN)的設(shè)計(jì)借鑒了sequence to sequence的思想,不同的是它優(yōu)化的是一個(gè)不可微的目標(biāo),也就是 網(wǎng)絡(luò)的accuracy。
方法

上圖展示了如何使用RNN控制器產(chǎn)生一個(gè)簡單的CNN網(wǎng)絡(luò),對于CNN網(wǎng)絡(luò)的每一層,控制器都會產(chǎn)生一組超參數(shù),當(dāng)層數(shù)達(dá)到一個(gè)閾值,就會停止。RNN的參數(shù)會通過增強(qiáng)學(xué)習(xí)算法更新,以得到更好的模型結(jié)構(gòu)。
使用REINFORCE來訓(xùn)練
控制器可以看作agent,控制器產(chǎn)生一組token,也就是超參數(shù),看作agent的action,使用產(chǎn)生的模型在驗(yàn)證集的準(zhǔn)確率作為reward。因此,控制器需要優(yōu)化下面公式:

但是

它的一個(gè)經(jīng)驗(yàn)近似公式如下:

是一個(gè)batch中的模型個(gè)數(shù)
是超參數(shù)的個(gè)數(shù)
由于以上公式會遇到variance過大的問題,可以使用如下帶baseline的公式

分布式訓(xùn)練加速
分布式框架如下圖

思路:其中 parameter server共同保存了控制器的所有參數(shù),這些server將參數(shù)分發(fā)給controller,每一個(gè)controller使用得到的參數(shù)進(jìn)行模型的構(gòu)建,這里由于得到的參數(shù)可能不同,構(gòu)建模型的策略是隨機(jī)的,導(dǎo)致每次構(gòu)建的網(wǎng)絡(luò)結(jié)構(gòu)也會不同。每個(gè)controller會構(gòu)建一個(gè)batch,也就是
skip connection and other layer types
為了能夠在搜索空間中加入類似于resnet,inception的skip connection。作者設(shè)計(jì)了anchor,用于表示是否和前面幾個(gè)層進(jìn)行連接。如下

相對應(yīng)的,agent的action選擇如下圖:

它會根據(jù)前面算出的概率和自己的策略,判斷是否加入connection。
最終,所有沒有后續(xù)連接的層都會被連接到輸出層,如果連接的兩個(gè)層大小不一致,就將小的層用0來填充。(pad with zeros)
產(chǎn)生RNN網(wǎng)絡(luò)cell
為了產(chǎn)生RNN網(wǎng)絡(luò)cell,類似于LSTM,作者使用了一種樹的結(jié)構(gòu),每一個(gè)樹的節(jié)點(diǎn)都會擁有一個(gè)操作(addition, elementwise multiplication, etc.)和一個(gè)激活函數(shù)(tanh, sigmoid等)。每一個(gè)節(jié)點(diǎn)的輸入,都連接了兩個(gè)其他節(jié)點(diǎn)的輸出。為了使用上面描述的方法,作者將每個(gè)節(jié)點(diǎn)編號,按照順序預(yù)測。如下圖:

根據(jù)預(yù)測的結(jié)果,將會按照如下方式構(gòu)建網(wǎng)絡(luò):

總結(jié)
這篇文章將增強(qiáng)學(xué)習(xí)的算法應(yīng)用在了模型預(yù)測上,并且巧妙地使用RNN來預(yù)測參數(shù)??傮w思路依舊是通過在一個(gè)有限的搜索空間進(jìn)行高效的搜索,來不斷提高agent預(yù)測的模型的準(zhǔn)確率。
note:REINFORCE算法真神奇,能夠直接使用一個(gè)簡單的標(biāo)量reward來知道agent更新參數(shù)的方式。