這幾天學(xué)習(xí)了一下softmax激活函數(shù),以及它的梯度求導(dǎo)過(guò)程,整理一下便于分享和交流!
一、softmax函數(shù)
softmax用于多分類過(guò)程中,它將多個(gè)神經(jīng)元的輸出,映射到(0,1)區(qū)間內(nèi),可以看成概率來(lái)理解,從而來(lái)進(jìn)行多分類!
假設(shè)我們有一個(gè)數(shù)組[a1, a2, a3 ... an],則
Si= e^ai / sum(e^a1 + e^a2 + ... +e^an)
如下圖:

softmax直白來(lái)說(shuō)就是將原來(lái)輸出是3,1,-3通過(guò)softmax函數(shù)一作用,就映射成為(0,1)的值,而這些值的累和為1(滿足概率的性質(zhì)),那么我們就可以將它理解成概率,在最后選取輸出結(jié)點(diǎn)的時(shí)候,我們就可以選取概率最大(也就是值對(duì)應(yīng)最大的)結(jié)點(diǎn),作為我們的預(yù)測(cè)目標(biāo)!
二、softmax相關(guān)求導(dǎo)
當(dāng)我們對(duì)分類的Loss進(jìn)行改進(jìn)的時(shí)候,我們要通過(guò)梯度下降,每次優(yōu)化一個(gè)step大小的梯度,這個(gè)時(shí)候我們就要求Loss對(duì)每個(gè)權(quán)重矩陣的偏導(dǎo),然后應(yīng)用鏈?zhǔn)椒▌t。那么這個(gè)過(guò)程的第一步,就是對(duì)softmax求導(dǎo)傳回去,不用著急,我后面會(huì)舉例子非常詳細(xì)的說(shuō)明。在這個(gè)過(guò)程中,你會(huì)發(fā)現(xiàn)用了softmax函數(shù)之后,梯度求導(dǎo)過(guò)程非常非常方便!

我們能得到下面公式:
z4 = w41o1+w42o2+w43*o3
z5 = w51o1+w52o2+w53*o3
z6 = w61o1+w62o2+w63*o3
z4,z5,z6分別代表結(jié)點(diǎn)4,5,6的輸出,01,02,03代表是結(jié)點(diǎn)1,2,3往后傳的輸入.
那么我們可以經(jīng)過(guò)softmax函數(shù)得到
a4= e^z4 / sum(e^z4 + e^z5 + e^z6)
a5= e^z5 / sum(e^z4 + e^z5 + e^z6)
a6= e^z6 / sum(e^z4 + e^z5 + e^z6)
好了,我們的重頭戲來(lái)了,怎么根據(jù)求梯度,然后利用梯度下降方法更新梯度!**
要使用梯度下降,肯定需要一個(gè)損失函數(shù),這里我們使用交叉熵作為我們的損失函數(shù),為什么使用交叉熵?fù)p失函數(shù),不是這篇文章重點(diǎn),后面有時(shí)間會(huì)單獨(dú)寫一下為什么要用到交叉熵函數(shù)(這里我們默認(rèn)選取它作為損失函數(shù))
交叉熵函數(shù)形式如下:
loss = -∑yi * log ai
其中y代表我們的真實(shí)值,a代表我們softmax求出的值。i代表的是輸出結(jié)點(diǎn)的標(biāo)號(hào)!在上面例子,i就可以取值為4,5,6三個(gè)結(jié)點(diǎn)(當(dāng)然我這里只是為了簡(jiǎn)單,真實(shí)應(yīng)用中可能有很多結(jié)點(diǎn))
現(xiàn)在看起來(lái)是不是感覺(jué)復(fù)雜了,居然還有累和,然后還要求導(dǎo),每一個(gè)a都是softmax之后的形式!
但是實(shí)際上不是這樣的,我們往往在真實(shí)中,如果只預(yù)測(cè)一個(gè)結(jié)果,那么在目標(biāo)中只有一個(gè)結(jié)點(diǎn)的值為1,比如我認(rèn)為在該狀態(tài)下,我想要輸出的是第四個(gè)動(dòng)作(第四個(gè)結(jié)點(diǎn)),那么訓(xùn)練數(shù)據(jù)的輸出就是a4 = 1,a5=0,a6=0,哎呀,這太好了,除了一個(gè)為1,其它都是0,那么所謂的求和符合,就是一個(gè)幌子,我可以去掉啦!
為了形式化說(shuō)明,我這里認(rèn)為訓(xùn)練數(shù)據(jù)的真實(shí)輸出為第j個(gè)為1,其它均為0!
那么Loss就變成了loss = -yi * log ai,累和已經(jīng)去掉了,太好了?,F(xiàn)在我們要開(kāi)始求導(dǎo)數(shù)了!
我們?cè)谡硪幌律厦婀剑瑸榱烁用靼椎目闯鱿嚓P(guān)變量的關(guān)系:yj = 1, loss = - log ai
那么形式越來(lái)越簡(jiǎn)單了,求導(dǎo)分析如下:
參數(shù)的形式在該例子中,總共分為w41,w42,w43,w51,w52,w53,w61,w62,w63.這些,那么比如我要求出w41,w42,w43的偏導(dǎo),就需要將Loss函數(shù)求偏導(dǎo)傳到結(jié)點(diǎn)4,然后再利用鏈?zhǔn)椒▌t繼續(xù)求導(dǎo)即可,舉個(gè)例子此時(shí)求w41的偏導(dǎo)為:

w51.....w63等參數(shù)的偏導(dǎo)同理可以求出,那么我們的關(guān)鍵就在于Loss函數(shù)對(duì)于結(jié)點(diǎn)4,5,6的偏導(dǎo)怎么求,如下:
這里分為倆種情況:




三、softmax實(shí)現(xiàn)(python)

