在NLP領(lǐng)域,在神經(jīng)網(wǎng)絡(luò)興起之前,條件隨機(jī)場(chǎng)(CRF)一直是作為主力模型的存在,就算是在RNN系(包括BERT系)的模型興起之后,也通常會(huì)在模型的最后添加一個(gè)CRF層,以提高準(zhǔn)確率。因此,CRF是所有NLPer必須要精通且掌握的一個(gè)模型,本文將優(yōu)先闡述清楚與CRF有關(guān)的全部基本概念,并詳細(xì)對(duì)比HMM,最后獻(xiàn)上BI-LSTM+CRF的實(shí)戰(zhàn)代碼及理解。相信讀完本文,將對(duì)CRF的認(rèn)識(shí)有一個(gè)新的高度。
在閱讀本文之前,務(wù)必對(duì)概率圖模型基礎(chǔ)有一個(gè)全盤的掌握,若對(duì)此沒(méi)有信心的,可以先參考我之前的一篇總結(jié)文:概率圖模型基礎(chǔ)
本文的基本目錄如下:
基礎(chǔ)知識(shí)
1.1 CRF到底是什么?
1.2 如何用CRF建模?
1.3 CRF與HMM的區(qū)別是什么?BILSTM+CRF實(shí)戰(zhàn)
2.1 為什么需要添加CRF層?
2.2 如何計(jì)算損失函數(shù)?
2.3 實(shí)戰(zhàn)環(huán)節(jié)
------------------第一菇 - 基礎(chǔ)知識(shí)------------------
1.1 CRF到底是什么?
本段主要用于講述與CRF有關(guān)的基礎(chǔ)概念。
大部分人理解CRF都會(huì)被帶到一個(gè)奇怪的誤區(qū)里面(包括我之前),因?yàn)榭偸抢斫馔炅薍MM以后,就會(huì)立馬投入到CRF的學(xué)習(xí)里面,所以就會(huì)理所當(dāng)然的認(rèn)為CRF就是HMM的升級(jí)版(確實(shí)從模型效果上可以這么理解),然后一直把HMM的各自概念往CRF上套,之后兩廂一對(duì)比,就會(huì)有點(diǎn)犯迷糊了,好多東西也對(duì)不上啊??~然后,再一看書里的結(jié)論,啥?CRF竟然是判別式模型,HMM是生成式模型!這是什么鬼啦?CRF不就是HMM解除各自限制(有向圖變無(wú)向圖,箭頭指的方向更多了嗎??????),怎么突然就變成判別式模型啦?。繌U話不多說(shuō),如果有此疑問(wèn)的同學(xué),就說(shuō)明看對(duì)文章了,不要心急慢慢看,我將一一解釋清楚;而對(duì)此毫無(wú)疑問(wèn)的同學(xué),那可以直接跳到實(shí)戰(zhàn)環(huán)節(jié)了哈。
CRF真的是判別式模型!準(zhǔn)確說(shuō)是,判別式無(wú)向圖模型!
大家要牢記,區(qū)別判別式模型與生成式模型最基本的就是去判斷模型是對(duì)聯(lián)合分布進(jìn)行建模,還是對(duì)條件分布進(jìn)行建模。HMM中很顯然,模型是對(duì)的聯(lián)合分布進(jìn)行建模(不清楚的同學(xué),還請(qǐng)移步HMM專區(qū)),而CRF則不然,其試圖對(duì)多個(gè)變量在給定觀測(cè)值后的條件概率進(jìn)行建模,因此屬于判別式模型。(各位抱著學(xué)新東西的心態(tài)來(lái)學(xué)CRF,把HMM拋在腦后把)
具體展開(kāi)來(lái)看,若令為觀測(cè)序列,
為與之對(duì)應(yīng)的標(biāo)記序列,則條件隨機(jī)場(chǎng)的目標(biāo)是構(gòu)建條件概率模型
。值得注意的是,標(biāo)記變量
可以是結(jié)構(gòu)型變量,即其分量之間具有某種相關(guān)性。就比如在NLP領(lǐng)域中的詞性標(biāo)注任務(wù),觀測(cè)數(shù)據(jù)為單詞序列(即為
),標(biāo)記為相應(yīng)的詞性序列(即為
),且其具有線性序列結(jié)構(gòu)。
1.2 如何用CRF建模?
令表示結(jié)點(diǎn)與標(biāo)記變量
中元素一一對(duì)應(yīng)的無(wú)向圖,
表示與結(jié)點(diǎn)
對(duì)應(yīng)的標(biāo)記變量,
表示結(jié)點(diǎn)
的鄰接結(jié)點(diǎn),若圖
的每個(gè)變量
都滿足馬爾可夫性(即只與其相鄰的結(jié)點(diǎn)有關(guān)),即,
則構(gòu)成一個(gè)條件隨機(jī)場(chǎng)。而理論上來(lái)說(shuō),圖G可具有任意結(jié)構(gòu),只要能表示標(biāo)記變量之間的條件獨(dú)立性關(guān)系即可,但在現(xiàn)實(shí)應(yīng)用中,尤其是對(duì)標(biāo)記序列建模時(shí)候,最常用的仍然是鏈?zhǔn)浇Y(jié)構(gòu),即“鏈?zhǔn)綏l件隨機(jī)場(chǎng)(chain-structured CRF)”,也是我們接下來(lái)主要要討論的一種條件隨機(jī)場(chǎng)。

那我們?cè)撊绾味x呢?
參考《機(jī)器學(xué)習(xí)》第14章的原文,其定義的方式類似馬爾可夫隨機(jī)場(chǎng)模型定義的聯(lián)合概率。條件隨機(jī)場(chǎng)使用勢(shì)函數(shù)和圖結(jié)構(gòu)上的團(tuán)來(lái)定義條件概率!如上圖所示,該鏈?zhǔn)綏l件隨機(jī)場(chǎng)主要包含兩種關(guān)于標(biāo)記變量的團(tuán),即單個(gè)標(biāo)記變量
以及相鄰的標(biāo)記變量
。選擇合適的勢(shì)函數(shù),即可得到形如馬爾可夫隨機(jī)場(chǎng)中聯(lián)合概率的定義。
在條件隨機(jī)場(chǎng)中,通過(guò)選用指數(shù)勢(shì)函數(shù)并引入特征函數(shù),條件概率被定義如下,
注意哈,這里的指的是整一個(gè)觀測(cè)序列!而且這里定義的條件概率計(jì)算方式,就只是將觀測(cè)序列
作為條件,并不對(duì)其作任何獨(dú)立性假設(shè)?。。。ㄟ@點(diǎn)很重要!也是其是判別式模型的重要依據(jù))
其中,是定義在觀測(cè)序列的倆個(gè)相鄰標(biāo)記位置上的轉(zhuǎn)移特征函數(shù),用于刻畫相鄰標(biāo)記變量之間的相關(guān)關(guān)系以及觀測(cè)序列對(duì)他們的影響。即給定觀測(cè)序列
,其標(biāo)注序列在
及
位置上標(biāo)記的轉(zhuǎn)移概率!而特征函數(shù)的定義往往不止一種,因此會(huì)有一個(gè)下標(biāo)
代表要遍歷計(jì)算每一種特征函數(shù)的取值。
另外,是定義在觀測(cè)序列的標(biāo)記位置
上的狀態(tài)特征函數(shù),用于刻畫觀測(cè)序列對(duì)標(biāo)記變量的影響。即表示對(duì)于觀察序列
其
位置的標(biāo)記概率。同理,也有多種特征函數(shù),所以會(huì)有下標(biāo)
。
剩下的就比較簡(jiǎn)單理解,都是參數(shù),
為規(guī)范化因子,用于確保上式是被正確定義的概率(可以理解為類似softmax的操作)。
總結(jié)一下上式,可以理解為如下圖,

至此,整個(gè)概率的定義想必大家已經(jīng)爛熟于心了~顯然,要運(yùn)用好條件隨機(jī)場(chǎng),最重要的就是要去定義合適的特征函數(shù)了。特征函數(shù)通常是實(shí)值函數(shù),以刻畫數(shù)據(jù)的一些很可能成立或期望成立的經(jīng)驗(yàn)特性。因此定義特征函數(shù)的時(shí)候,一般都可以定義一組關(guān)于觀察序列的二值特征
來(lái)表示訓(xùn)練樣本中某些分布特性,比如詞性標(biāo)注任務(wù),
等等類似的特征函數(shù),能定義好多出來(lái)的。因此,小小總結(jié)一下,CRF與馬爾可夫隨機(jī)場(chǎng)均使用團(tuán)上的勢(shì)函數(shù)定義概率,兩者在形式上并沒(méi)有顯著區(qū)別,只不過(guò)CRF處理的是條件概率,而馬爾可夫隨機(jī)場(chǎng)處理的是聯(lián)合概率。至此,整個(gè)CRF的建模已經(jīng)講明白了。
1.3 CRF與HMM的區(qū)別是什么?
CRF與HMM的一些基本定義的概念區(qū)別這邊在講概率圖模型和上面的基礎(chǔ)定義時(shí)已經(jīng)表述的很清楚了,本段就不繼續(xù)展開(kāi)了~這里主要講一下HMM的標(biāo)注偏置問(wèn)題,以及CRF為何能解決這個(gè)問(wèn)題。
其實(shí)要想解釋清楚標(biāo)注偏置問(wèn)題,大家只要看如下貼的一張圖即可,

大家可以發(fā)現(xiàn),狀態(tài)1傾向于轉(zhuǎn)移到狀態(tài)2,狀態(tài)2傾向于轉(zhuǎn)移到狀態(tài)2本身,但是實(shí)際計(jì)算得到的最大概率路徑是1>1>1>1,狀態(tài)1并沒(méi)有轉(zhuǎn)移到狀態(tài)2!這其實(shí)是與我們的直覺(jué)相悖的~究其本質(zhì)原因,從狀態(tài)2轉(zhuǎn)移出去可能的狀態(tài)包括1,2,3,4,5,概率在可能的狀態(tài)上分散了,而狀態(tài)1轉(zhuǎn)移出去的可能狀態(tài)僅僅為狀態(tài)1和2,概率更加集中?。ù蠹铱梢阅霉P算一下,是不是這么個(gè)理~加深理解)由于局部歸一化的影響,隱狀態(tài)會(huì)傾向于轉(zhuǎn)移到那些后續(xù)狀態(tài)可能更少的狀態(tài)上,以提高整體的后驗(yàn)概率!這就是標(biāo)注偏置問(wèn)題!
而CRF如上所述,因?yàn)橛袣w一化因子(Z)的存在,其在全局范圍內(nèi)進(jìn)行了歸一化,枚舉了整個(gè)隱狀態(tài)序列的全部可能,從而解決了局部歸一化帶來(lái)的標(biāo)注偏置問(wèn)題。而這也是CRF在很多問(wèn)題上,表現(xiàn)比HMM優(yōu)秀的原因~
------------------第二菇 - BILSTM+CRF實(shí)戰(zhàn)------------------
介紹了這么多CRF有關(guān)的東西,想必各位也是躍躍欲試,我這邊也獻(xiàn)上一份BILSTM+CRF的實(shí)戰(zhàn)解析,包括對(duì)此模型架構(gòu)的理解以及源碼的解讀~
這里貼一個(gè)鏈接,是一個(gè)外國(guó)小哥寫的博客,當(dāng)初就是看這個(gè)博客明白其原理的,所以也特地在這邊貼出來(lái),英文好的同學(xué)也可以直接看這個(gè)鏈接,我就不單獨(dú)放在參考文獻(xiàn)里了。
為了便于理解,代碼都是用Pytorch寫的,且還是以命名實(shí)體識(shí)別任務(wù)為具體例子。
如果看到這篇文章的是初學(xué)者,也不用慌,就簡(jiǎn)單理解BILSTM和CRF為一個(gè)命名實(shí)體識(shí)別模型中的兩個(gè)層。
為了便于理解下面的圖示,這邊假設(shè)我們的數(shù)據(jù)集有兩大類,人名和地名,與之相對(duì)應(yīng)在我們的訓(xùn)練數(shù)據(jù)集中,有五類標(biāo)簽:
* B-Person
* I-Person
* B-Organization
* I-Organization
* O
假設(shè)句子由5個(gè)字符組成,即
,其中
為人名實(shí)體,
為組織實(shí)體,其他字符的標(biāo)簽為"O"。
2.1 為什么需要添加CRF層?
這里先直接貼一張BILSTM-CRF的模型結(jié)構(gòu)圖,方便大家理解。

從下往上看,最下面就是輸入(字或詞向量),由于是序列模型,因此,在“時(shí)間”緯度上進(jìn)行展開(kāi),就可以得到如圖所示的模型表示,對(duì)應(yīng)于一個(gè)時(shí)刻,就是輸入一個(gè)字/詞向量(一般都是預(yù)先訓(xùn)練得出的)。
首先,是經(jīng)過(guò)BiLSTM的結(jié)構(gòu)單元。這個(gè)比較好理解,本質(zhì)上就是倆個(gè)LSTM層,只不過(guò)一次是正序輸入,一次是倒序輸入,然后把倆個(gè)結(jié)果進(jìn)行concact(拼接),并輸入到CRF層,最后由CRF層輸出每一個(gè)詞的標(biāo)簽~如果沒(méi)有CRF層的話,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)都會(huì)加一層softmax層用于歸一化并輸出每個(gè)標(biāo)簽概率。
為了更容易理解CRF層的作用,我們還是先要理清Bi-LSTM的輸出。這里再貼一張圖,方便大家理解,

大家可以看到,其輸出十分簡(jiǎn)單清晰,就是對(duì)于每一個(gè)單詞,其對(duì)應(yīng)每一個(gè)標(biāo)簽的分值(score)。因此,就算沒(méi)有CRF層,該模型依舊有效,我們只需要挑選每一個(gè)標(biāo)簽對(duì)應(yīng)最大的分值就可以,比如,就是
。因此,在原有模型本身就有效的情況下,我們?cè)偬砑右粚覥RF的目的肯定只有一個(gè),即提高模型的準(zhǔn)確率。
接下來(lái),我們就要重點(diǎn)分析一下,CRF層的作用。先上結(jié)論,CRF層的主要作用是為最后預(yù)測(cè)的標(biāo)簽添加一些約束來(lái)保證預(yù)測(cè)標(biāo)簽的合理性!比如,在命名實(shí)體識(shí)別任務(wù)中,我們可以想到的約束可以是,
1)開(kāi)頭的標(biāo)簽只能是B, O,而不可能是I-
2)B-Person開(kāi)頭的標(biāo)簽,后面不可能接一個(gè)I-Organization
。。。
有了如上約束,我們就能保證,最終預(yù)測(cè)生成的標(biāo)簽序列的不合理性就會(huì)大大降低,而單憑BiLSTM的輸出來(lái)預(yù)測(cè)是無(wú)法保證標(biāo)簽序列的合理性的~
2.2 損失函數(shù)的定義及計(jì)算
弄清楚了CRF層的作用以后,我們就要來(lái)仔細(xì)研究研究CRF層的運(yùn)行原理了,主要從其損失函數(shù)的角度來(lái)理解。
2.2.1 CRF中的兩種分?jǐn)?shù)
在CRF層的損失函數(shù)中,有兩種形式的score(分?jǐn)?shù)),第一個(gè)就是emission score(發(fā)射分?jǐn)?shù)),主要就是來(lái)自于BiLSTM層的輸出(如上圖所示),假設(shè)我們給每一個(gè)標(biāo)簽一個(gè)索引,那么第一個(gè)單詞的emission score就是,~
注:大家可千萬(wàn)注意了,不要把這里的score和HMM里面的發(fā)射概率矩陣相混淆!兩者可是完全不一樣的,在HMM中,每一個(gè)單詞的發(fā)射概率,僅與當(dāng)前的隱狀態(tài)層有關(guān)!是由隱狀態(tài)決定了當(dāng)前單詞的發(fā)射概率!而這里的發(fā)射分?jǐn)?shù),是由當(dāng)前輸入的序列,決定的當(dāng)前狀態(tài)的概率!這里是整一個(gè)序列哦!若大家還能記得CRF層的定義和概率模型圖(往上翻一翻),想必對(duì)此并不會(huì)驚訝!而這,也是CRF層能接在神經(jīng)網(wǎng)絡(luò)最后一層的主要原因!大家對(duì)此一定要有深刻的理解和認(rèn)識(shí)。
第二個(gè)就是transition score(轉(zhuǎn)移分?jǐn)?shù)),這個(gè)倒是跟HMM中的狀態(tài)概率轉(zhuǎn)移矩陣相類似,也很好理解,也是模型中主要學(xué)習(xí)的參數(shù)!而且為了使模型更具有魯棒性,我們額外增加了倆個(gè)標(biāo)簽,和
,
代表句子的開(kāi)始位置,而非第一個(gè)詞,同理
代表句子的結(jié)束位置,這里也貼一張transition score矩陣的圖,方便大家理解,

大家從圖中應(yīng)該很清楚可以看到,其能學(xué)到很多約束規(guī)則!因此,該矩陣也是模型主要訓(xùn)練的一個(gè)參數(shù),一般一開(kāi)始都會(huì)初始化一個(gè)概率轉(zhuǎn)移矩陣,隨著訓(xùn)練的迭代,逐漸合理~因此,接下來(lái),我們就要來(lái)看看,其損失函數(shù)是如何設(shè)計(jì)的,才能學(xué)到合理的參數(shù)~
2.2.2 損失函數(shù)的設(shè)計(jì)
先明確一點(diǎn),損失函數(shù)就是我們要優(yōu)化的目標(biāo),那對(duì)于這樣一個(gè)序列標(biāo)注問(wèn)題,我們肯定是希望,正確的序列,是所有的可能序列中,得分最高的!就如同我們作HMM解碼的時(shí)候,利用維特比算法解碼,我們返回的肯定是概率最大的那條路徑一般,那反過(guò)來(lái),我們訓(xùn)練的時(shí)候,自然希望得到的參數(shù),能使得正確路徑的概率最大~
有了上述的核心思想,我們?cè)賮?lái)想一想,如何求解。假設(shè)一共有種可能的標(biāo)簽序列組合,記第i個(gè)標(biāo)簽序列的得分為
,那么所有可能標(biāo)簽序列組合的總得分為,
因此,我們可以設(shè)想出一個(gè)損失函數(shù),就是真實(shí)序列的分?jǐn)?shù)在所有可能的序列中占比最高,即,
由這個(gè)損失函數(shù),引出2個(gè)問(wèn)題,
1)如何定義計(jì)算每一個(gè)序列的得分?
2)如何計(jì)算所有標(biāo)簽序列的總得分?
2.2.3 求解一個(gè)序列的得分
先來(lái)看第一個(gè)問(wèn)題,上述曾提過(guò)倆個(gè)分?jǐn)?shù)概念,emission score和transition score,因此,一個(gè)序列的得分也有這倆個(gè)構(gòu)成。
看到這個(gè)公式,大家再與CRF定義相聯(lián)系,有木有看出點(diǎn)什么花頭?沒(méi)錯(cuò),這個(gè)跟CRF定義的條件概率幾乎是類似的哦,基本上可以理解為,BiLSTM的輸出(也就是EmissionScore)取代來(lái)狀態(tài)特征函數(shù)的位置,而我們要學(xué)習(xí)的參數(shù)也就是轉(zhuǎn)移特征函數(shù)及其權(quán)重。所有,CRF與神經(jīng)網(wǎng)絡(luò)的配套組合并不是強(qiáng)行加上或者巧合,而是有理論作強(qiáng)支撐的哈哈~
我們逐一來(lái)理解每一個(gè)分?jǐn)?shù)的計(jì)算過(guò)程,假設(shè)我們有一個(gè)正確的序列標(biāo)注為,
那么,
其中,就表示第index個(gè)詞被標(biāo)記為label的得分(直接是從神經(jīng)網(wǎng)絡(luò)的輸出能拿到的)。
而另一個(gè)轉(zhuǎn)移分?jǐn)?shù)即為,
其中,就表示label1 到 label2的轉(zhuǎn)移概率,也就是模型要學(xué)習(xí)的參數(shù)~
2.2.4 計(jì)算所有序列總分?jǐn)?shù)的方法
至此,每一條路徑的總得分,就可以根據(jù)上面的式子很輕松的計(jì)算得出~但顯然,如果真實(shí)計(jì)算也是如此遍歷操作的話,時(shí)間復(fù)雜度會(huì)吃不消的,因此我們需要一個(gè)高效的算法來(lái)計(jì)算~
我們先簡(jiǎn)化一下?lián)p失函數(shù),

簡(jiǎn)化之后,可以很輕松的看出,前半部分的計(jì)算是固定的,我們只需要高效的計(jì)算出后半部分即可,
很明顯,這里會(huì)運(yùn)用到動(dòng)態(tài)規(guī)劃的思想(不懂動(dòng)態(tài)規(guī)劃的,直接去看一下維特比算法,加深理解)來(lái)進(jìn)行求解,即利用的總得分來(lái)推出
的總得分,最后以此類推,每一次計(jì)算都需要利用到上一步計(jì)算得到的結(jié)果。
這里舉一個(gè)簡(jiǎn)單的示例,假設(shè)句子長(zhǎng)度為3(),標(biāo)簽有2個(gè)(
),我們學(xué)到的Emission Score 矩陣如下(BiLSTM輸出),
學(xué)習(xí)到的Transition Score矩陣如下,
接下來(lái),將演示如何計(jì)算總得分,因?yàn)槭莿?dòng)態(tài)規(guī)劃的思想,只需演繹出其中一步即可~
針對(duì),很輕松,因?yàn)闆](méi)有轉(zhuǎn)移分?jǐn)?shù),僅有發(fā)射分?jǐn)?shù),因此在第一個(gè)位置的總得分即為兩種路徑的分?jǐn)?shù)總和,而現(xiàn)在兩種路徑就是兩種可能性,要么就是標(biāo)簽1要么就是標(biāo)簽2,而這也是整個(gè)動(dòng)態(tài)規(guī)劃開(kāi)始的初始條件~)
接下來(lái),我們要求在第二個(gè)位置的總得分,注意我們的推導(dǎo)式子就是由
推出的
,因此我們直接利用
計(jì)算得出的分?jǐn)?shù)即可~如下圖所演示的~

上述的動(dòng)態(tài)核心式子即為,
最終,將在位置的所有狀態(tài)求和得分相加即是總路徑的得分~
有人可能會(huì)問(wèn),那你這不是只有一個(gè)位置的得分了嗎?我們不是要求得總路徑得分嗎?有這個(gè)疑惑的同學(xué)應(yīng)該就是還沒(méi)有領(lǐng)悟到動(dòng)態(tài)規(guī)劃的精髓,建議自己手動(dòng)推導(dǎo)一遍,便可迎刃而解~至此,整一個(gè)所有序列路徑的求和方法,我們已經(jīng)大致了解清楚了~(多提一句,預(yù)測(cè)階段的解碼思路,其實(shí)就是維特比算法,也是動(dòng)態(tài)規(guī)劃的思路,十分簡(jiǎn)單,這里就不多說(shuō)了~)
2.3 實(shí)戰(zhàn)環(huán)節(jié)
上面兩節(jié)已經(jīng)把BiLSTM+CRF講的清清楚楚了~光看理論還不夠,我們要深入代碼實(shí)戰(zhàn)環(huán)節(jié)(注:此乃網(wǎng)上找的一個(gè)Pytorch的版本,個(gè)人覺(jué)得是寫的比較好的,只限用于理解理論,并非商業(yè)應(yīng)用)
我們首先導(dǎo)入相應(yīng)的包和定義一些后面要用到的輔助函數(shù),如下,
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(1)
# some helper functions
def argmax(vec):
# return the argmax as a python int
# 第1維度上最大值的下標(biāo)
# input: tensor([[2,3,4]])
# output: 2
_, idx = torch.max(vec,1)
return idx.item()
def prepare_sequence(seq,to_ix):
# 文本序列轉(zhuǎn)化為index的序列形式
idxs = [to_ix[w] for w in seq]
return torch.tensor(idxs, dtype=torch.long)
def log_sum_exp(vec):
#compute log sum exp in a numerically stable way for the forward algorithm
# 用數(shù)值穩(wěn)定的方法計(jì)算正演算法的對(duì)數(shù)和exp
# input: tensor([[2,3,4]])
# max_score_broadcast: tensor([[4,4,4]])
max_score = vec[0, argmax(vec)]
max_score_broadcast = max_score.view(1,-1).expand(1,vec.size()[1])
return max_score+torch.log(torch.sum(torch.exp(vec-max_score_broadcast)))
START_TAG = "<s>"
END_TAG = "<e>"
這里定義的幾個(gè)輔助函數(shù)都比較直觀,唯獨(dú)log_sum_exp可能會(huì)對(duì)大家造成一點(diǎn)困擾,但其實(shí)這是一種考慮數(shù)值穩(wěn)定性的求解辦法,具體大家參考這篇博文即可,深究一下也是好事情,不深究的就明白這個(gè)函數(shù)是為了求即可~
我們接著看模型的定義,
# create model
class BiLSTM_CRF(nn.Module):
def __init__(self,vocab_size, tag2ix, embedding_dim, hidden_dim):
super(BiLSTM_CRF,self).__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.tag2ix = tag2ix
self.tagset_size = len(tag2ix)
self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim//2, num_layers=1, bidirectional=True)
# maps output of lstm to tog space
self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
# matrix of transition parameters
# entry i, j is the score of transitioning to i from j
# tag間的轉(zhuǎn)移矩陣,是CRF層的參數(shù)
self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
# these two statements enforce the constraint that we never transfer to the start tag
# and we never transfer from the stop tag
self.transitions.data[tag2ix[START_TAG], :] = -10000
self.transitions.data[:, tag2ix[END_TAG]] = -10000
self.hidden = self.init_hidden()
def init_hidden(self):
return (torch.randn(2, 1,self.hidden_dim//2),
torch.randn(2, 1,self.hidden_dim//2))
def _forward_alg(self, feats):
# to compute partition function
# 求歸一化項(xiàng)的值,應(yīng)用動(dòng)態(tài)歸化算法
init_alphas = torch.full((1,self.tagset_size), -10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
# START_TAG has all of the score
init_alphas[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])
forward_var = init_alphas
for feat in feats:
#feat指Bi-LSTM模型每一步的輸出,大小為tagset_size
alphas_t = []
for next_tag in range(self.tagset_size):
# 取其中的某個(gè)tag對(duì)應(yīng)的值進(jìn)行擴(kuò)張至(1,tagset_size)大小
# 如tensor([3]) -> tensor([[3,3,3,3,3]])
emit_score = feat[next_tag].view(1,-1).expand(1,self.tagset_size)
# 增維操作
trans_score = self.transitions[next_tag].view(1,-1)
# 上一步的路徑和+轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù)
next_tag_var = forward_var + trans_score + emit_score
# log_sum_exp求和
alphas_t.append(log_sum_exp(next_tag_var).view(1))
# 增維
forward_var = torch.cat(alphas_t).view(1,-1)
terminal_var = forward_var+self.transitions[self.tag2ix[END_TAG]]
alpha = log_sum_exp(terminal_var)
#歸一項(xiàng)的值
return alpha
def _get_lstm_features(self,sentence):
self.hidden = self.init_hidden()
embeds = self.word_embeds(sentence).view(len(sentence),1,-1)
lstm_out, self.hidden = self.lstm(embeds, self.hidden)
lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
lstm_feats = self.hidden2tag(lstm_out)
return lstm_feats
def _score_sentence(self,feats,tags):
# gives the score of a provides tag sequence
# 求某一路徑的值
score = torch.zeros(1)
tags = torch.cat([torch.tensor([self.tag2ix[START_TAG]], dtype=torch.long), tags])
for i , feat in enumerate(feats):
score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
score = score + self.transitions[self.tag2ix[END_TAG], tags[-1]]
return score
def _viterbi_decode(self, feats):
# 當(dāng)參數(shù)確定的時(shí)候,求解最佳路徑
backpointers = []
init_vars = torch.full((1,self.tagset_size),-10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
init_vars[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])
forward_var = init_vars
for feat in feats:
bptrs_t = [] # holds the back pointers for this step
viterbivars_t = [] # holds the viterbi variables for this step
for next_tag in range(self.tagset_size):
next_tag_var = forward_var + self.transitions[next_tag]
best_tag_id = argmax(next_tag_var)
bptrs_t.append(best_tag_id)
viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
backpointers.append(bptrs_t)
# Transition to STOP_TAG
terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
best_tag_id = argmax(terminal_var)
path_score = terminal_var[0][best_tag_id]
# Follow the back pointers to decode the best path.
best_path = [best_tag_id]
for bptrs_t in reversed(backpointers):
best_tag_id = bptrs_t[best_tag_id]
best_path.append(best_tag_id)
# Pop off the start tag (we dont want to return that to the caller)
start = best_path.pop()
assert start == self.tag2ix[START_TAG] # Sanity check
best_path.reverse()
return path_score, best_path
def neg_log_likelihood(self, sentence, tags):
# 由lstm層計(jì)算得的每一時(shí)刻屬于某一tag的值
feats = self._get_lstm_features(sentence)
# 歸一項(xiàng)的值
forward_score = self._forward_alg(feats)
# 正確路徑的值
gold_score = self._score_sentence(feats, tags)
return forward_score - gold_score# -(正確路徑的分值 - 歸一項(xiàng)的值)
def forward(self, sentence): # dont confuse this with _forward_alg above.
# Get the emission scores from the BiLSTM
lstm_feats = self._get_lstm_features(sentence)
# Find the best path, given the features.
score, tag_seq = self._viterbi_decode(lstm_feats)
return score, tag_seq
上面的注釋應(yīng)該說(shuō)是很詳細(xì)了,一開(kāi)始的初始化定義也都是Pytorch的常規(guī)寫法(簡(jiǎn)書的代碼顯示的略詭異,大家將就看看吧)~LSTM層也是直接掉的nn里面的,只有CRF層是自己手?jǐn)]上來(lái)的~所以,大家重點(diǎn)關(guān)注一下_forward_alg這個(gè)函數(shù),就是我們上面講的求解路徑總得分的函數(shù)~其中feats就是序列步長(zhǎng),自然是要順序遍歷每一個(gè)feat,其中每一個(gè)feat又要遍歷每一種tag的情況,利用forward_var記錄每一個(gè)路徑的總得分(實(shí)時(shí)更新),最后在求和即可!應(yīng)該說(shuō)看懂了上面的解釋的同學(xué),在看這個(gè)代碼,簡(jiǎn)直是太簡(jiǎn)單了哈哈~其他的函數(shù)也沒(méi)啥好特地強(qiáng)調(diào)的,大家掃一眼明白即可,對(duì)解碼不清楚的,直接看代碼也難,手動(dòng)推演一遍,理解的更快~
最后,我們?cè)賮?lái)看一下主函數(shù),
if __name__ == "__main__":
EMBEDDING_DIM = 5
HIDDEN_DIM = 4
# Make up some training data
training_data = [(
"the wall street journal reported today that apple corporation made money".split(),
"B I I I O O O B I O O".split()
), (
"georgia tech is a university in georgia".split(),
"B I O O O O B".split()
)]
word2ix = {}
for sentence, tags in training_data:
for word in sentence:
if word not in word2ix:
word2ix[word] = len(word2ix)
tag2ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, END_TAG: 4}
model = BiLSTM_CRF(len(word2ix), tag2ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)
# Check predictions before training
# 輸出訓(xùn)練前的預(yù)測(cè)序列
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word2ix)
precheck_tags = torch.tensor([tag2ix[t] for t in training_data[0][1]], dtype=torch.long)
print(model(precheck_sent))
# Make sure prepare_sequence from earlier in the LSTM section is loaded
for epoch in range(300): # again, normally you would NOT do 300 epochs, it is toy data
for sentence, tags in training_data:
# Step 1. Remember that Pytorch accumulates gradients.
# We need to clear them out before each instance
model.zero_grad()
# Step 2. Get our inputs ready for the network, that is,
# turn them into Tensors of word indices.
sentence_in = prepare_sequence(sentence, word2ix)
targets = torch.tensor([tag2ix[t] for t in tags], dtype=torch.long)
# Step 3. Run our forward pass.
loss = model.neg_log_likelihood(sentence_in, targets)
# Step 4. Compute the loss, gradients, and update the parameters by
# calling optimizer.step()
loss.backward()
optimizer.step()
# Check predictions after training
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word2ix)
print(model(precheck_sent))
# 輸出結(jié)果
# (tensor(-9996.9365), [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
# (tensor(-9973.2725), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])
也是比較常規(guī)的寫法,還帶了示例~大家應(yīng)該很容易理解的!
至此,整一套跟CRF有關(guān)的知識(shí)點(diǎn)和代碼解釋已經(jīng)全部弄清楚了。簡(jiǎn)單總結(jié)一下本文,先是詳細(xì)解釋了一下與CRF有關(guān)的一些誤區(qū)和知識(shí)點(diǎn),接著展示了與CRF有關(guān)的用法和計(jì)算損失函數(shù)的方法,最后獻(xiàn)上了詳細(xì)的代碼解讀~希望大家讀完本文后對(duì)CRF的一些概念會(huì)有一個(gè)全新的認(rèn)識(shí)。有說(shuō)的不對(duì)的地方也請(qǐng)大家指出,多多交流,大家一起進(jìn)步~??