NEURAL ARCHITECTURE SEARCH WITH REINFORCEMENT LEARNING 筆記

這是一篇使用增強(qiáng)學(xué)習(xí)來進(jìn)行模型搜索的論文。
結(jié)構(gòu)如下圖:


overview

由于不知道網(wǎng)絡(luò)的長度和結(jié)構(gòu),作者使用了一個(gè)RNN作為控制器,使用該控制器來產(chǎn)生一串信息,用于構(gòu)建網(wǎng)絡(luò)。之后訓(xùn)練該網(wǎng)絡(luò),并用網(wǎng)絡(luò)的accuracy作為reward返回給控制器來更新控制器的參數(shù),達(dá)到更優(yōu)的策略。
其中控制器(RNN)的設(shè)計(jì)借鑒了sequence to sequence的思想,不同的是它優(yōu)化的是一個(gè)不可微的目標(biāo),也就是 網(wǎng)絡(luò)的accuracy。

方法

CNN

上圖展示了如何使用RNN控制器產(chǎn)生一個(gè)簡單的CNN網(wǎng)絡(luò),對于CNN網(wǎng)絡(luò)的每一層,控制器都會產(chǎn)生一組超參數(shù),當(dāng)層數(shù)達(dá)到一個(gè)閾值,就會停止。RNN的參數(shù)\theta_{t}會通過增強(qiáng)學(xué)習(xí)算法更新,以得到更好的模型結(jié)構(gòu)。

使用REINFORCE來訓(xùn)練

控制器可以看作agent,控制器產(chǎn)生一組token,也就是超參數(shù),看作agent的action,使用產(chǎn)生的模型在驗(yàn)證集的準(zhǔn)確率作為reward。因此,控制器需要優(yōu)化下面公式:

optimization target

但是
R
是不可微分的,因此不能使用傳統(tǒng)的BP算法,在論文中,作者使用了REINFORCE。該算法是增強(qiáng)學(xué)習(xí)的常用算法之一,算法將agent的policy看作一個(gè)函數(shù),通過reward來進(jìn)行參數(shù)的更新,從而實(shí)現(xiàn)reward的最優(yōu)化。并且該算法給出了rewardpolicy參數(shù)的導(dǎo)數(shù)公式。
REINFORCE

它的一個(gè)經(jīng)驗(yàn)近似公式如下:


empirical approximation

m是一個(gè)batch中的模型個(gè)數(shù)
T是超參數(shù)的個(gè)數(shù)

由于以上公式會遇到variance過大的問題,可以使用如下帶baseline的公式


with baseline

分布式訓(xùn)練加速

分布式框架如下圖

PS

思路:其中 parameter server共同保存了控制器的所有參數(shù),這些server將參數(shù)分發(fā)給controller,每一個(gè)controller使用得到的參數(shù)進(jìn)行模型的構(gòu)建,這里由于得到的參數(shù)可能不同,構(gòu)建模型的策略是隨機(jī)的,導(dǎo)致每次構(gòu)建的網(wǎng)絡(luò)結(jié)構(gòu)也會不同。每個(gè)controller會構(gòu)建一個(gè)batch,也就是
m
個(gè)網(wǎng)絡(luò),然后并行地訓(xùn)練這些網(wǎng)絡(luò),得到它們的accuracy。也就是說,每一個(gè)controller會得到一個(gè)batch也就是
m
個(gè)網(wǎng)絡(luò),和它們的accuracy,然后根據(jù)之前提到的公式,計(jì)算參數(shù)的梯度。接著,計(jì)算完梯度的controller會將梯度發(fā)送給servers。這些server在得到梯度后,分別對自己負(fù)責(zé)的參數(shù)進(jìn)行更新。更新后,當(dāng)controller再次訓(xùn)練時(shí),會得到更新后的參數(shù)。這里如果每個(gè)controller各自發(fā)送自己的梯度,之間不進(jìn)行同步,就是異步更新。

skip connection and other layer types

為了能夠在搜索空間中加入類似于resnet,inception的skip connection。作者設(shè)計(jì)了anchor,用于表示是否和前面幾個(gè)層進(jìn)行連接。如下


anchor

相對應(yīng)的,agent的action選擇如下圖:


agent action

它會根據(jù)前面算出的概率P和自己的策略,判斷是否加入connection。
最終,所有沒有后續(xù)連接的層都會被連接到輸出層,如果連接的兩個(gè)層大小不一致,就將小的層用0來填充。(pad with zeros)

產(chǎn)生RNN網(wǎng)絡(luò)cell

為了產(chǎn)生RNN網(wǎng)絡(luò)cell,類似于LSTM,作者使用了一種樹的結(jié)構(gòu),每一個(gè)樹的節(jié)點(diǎn)都會擁有一個(gè)操作(addition, elementwise multiplication, etc.)和一個(gè)激活函數(shù)(tanh, sigmoid等)。每一個(gè)節(jié)點(diǎn)的輸入,都連接了兩個(gè)其他節(jié)點(diǎn)的輸出。為了使用上面描述的方法,作者將每個(gè)節(jié)點(diǎn)編號,按照順序預(yù)測。如下圖:


RNN

根據(jù)預(yù)測的結(jié)果,將會按照如下方式構(gòu)建網(wǎng)絡(luò):


Computation steps

總結(jié)

這篇文章將增強(qiáng)學(xué)習(xí)的算法應(yīng)用在了模型預(yù)測上,并且巧妙地使用RNN來預(yù)測參數(shù)??傮w思路依舊是通過在一個(gè)有限的搜索空間進(jìn)行高效的搜索,來不斷提高agent預(yù)測的模型的準(zhǔn)確率。
note:REINFORCE算法真神奇,能夠直接使用一個(gè)簡單的標(biāo)量reward來知道agent更新參數(shù)的方式。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容