[干貨]深入淺出LSTM及其Python代碼實現(xiàn)

人工神經(jīng)網(wǎng)絡(luò)在近年來大放異彩,在圖像識別、語音識別、自然語言處理與大數(shù)據(jù)分析領(lǐng)域取得了巨大的成功,而長短期記憶網(wǎng)絡(luò)LSTM作為一種特殊的神經(jīng)網(wǎng)絡(luò)模型,它又有哪些特點呢?作為初學(xué)者,如何由淺入深地理解LSTM并將其應(yīng)用到實際工作中呢?本文將由淺入深介紹循環(huán)神經(jīng)網(wǎng)絡(luò)RNN和長短期記憶網(wǎng)絡(luò)LSTM的基本原理,并基于Pytorch實現(xiàn)一個簡單應(yīng)用例子,提供完整代碼。

1. 神經(jīng)網(wǎng)絡(luò)簡介

1.1 神經(jīng)網(wǎng)絡(luò)起源

人工神經(jīng)網(wǎng)絡(luò)(Aritificial Neural Networks, ANN)是一種仿生的網(wǎng)絡(luò)結(jié)構(gòu),起源于對人類大腦的研究。人工神經(jīng)網(wǎng)絡(luò)(Aritificial Neural Networks)也常被簡稱為神經(jīng)網(wǎng)絡(luò)(Neural Networks, NN),基本思想是通過大量簡單的神經(jīng)元之間的相互連接來構(gòu)造復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu),信號(數(shù)據(jù))可以在這些神經(jīng)元之間傳遞,通過激活不同的神經(jīng)元和對傳遞的信號進行加權(quán)來使得信號被放大或衰減,經(jīng)過多次的傳遞來改變信號的強度和表現(xiàn)形式。

神經(jīng)網(wǎng)絡(luò)最早起源于20世紀(jì)40年代,神經(jīng)科學(xué)家和控制論專家Warren McCulloch和邏輯學(xué)家Walter Pitts基于數(shù)學(xué)和閾值邏輯算法創(chuàng)造了最早的神經(jīng)網(wǎng)絡(luò)計算模型。由于當(dāng)時的計算資源有限,無法構(gòu)建層數(shù)太多的神經(jīng)網(wǎng)絡(luò)(3層以內(nèi)),因此神經(jīng)網(wǎng)絡(luò)的應(yīng)用范圍很局限。隨著計算機技術(shù)的發(fā)展,神經(jīng)網(wǎng)絡(luò)層數(shù)的增加帶來的計算負(fù)擔(dān)已經(jīng)可以被現(xiàn)代計算機解決,各位前輩大牛對于神經(jīng)網(wǎng)絡(luò)的理解也進一步加深。歷史上神經(jīng)網(wǎng)絡(luò)的發(fā)展大致經(jīng)歷了三次高潮:20世紀(jì)40年代的控制論、20世紀(jì)80年代到90年代中期的聯(lián)結(jié)主義和2006年以來的深度學(xué)習(xí)。深度學(xué)習(xí)的出現(xiàn)直接引爆了一部分應(yīng)用市場,這里有太多的案例可以講,想了解的讀者可參考下面的鏈接:

1.2 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的缺陷

本文假設(shè)讀者已經(jīng)了解神經(jīng)網(wǎng)絡(luò)的基本原理了,如果有讀者是初次接觸神經(jīng)網(wǎng)絡(luò)的知識,這里分享一篇個人覺得非常適合初學(xué)者的文章:

1.2.1 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的原理回顧

傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)可以用下面這張圖表示:

NN.jpg

其中:

  • 輸入層:可以包含多個神經(jīng)元,可以接收多維的信號輸入(特征信息);
  • 輸出層:可以包含多個神經(jīng)元,可以輸出多維信號;
  • 隱含層:可以包含多個神經(jīng)網(wǎng)絡(luò)層,每一層包含多個神經(jīng)元。

每層的神經(jīng)元與上一層神經(jīng)元和下一層神經(jīng)元連接(類似生物神經(jīng)元的突觸),這些連接通路用于信號傳遞。每個神經(jīng)元接收來自上一層的信號輸入,使用一定的加和規(guī)則將所有的信號輸入?yún)R聚到一起,并使用激活函數(shù)將輸入信號激活為輸出信號,再將信號傳遞到下一層。

神經(jīng)網(wǎng)絡(luò)為什么要使用激活函數(shù)?不同的激活函數(shù)有什么不同的作用?讀者可參考:

所以,影響神經(jīng)網(wǎng)絡(luò)表現(xiàn)能力的主要因素有神經(jīng)網(wǎng)絡(luò)的層數(shù)、神經(jīng)元的個數(shù)、神經(jīng)元之間的連接方式以及神經(jīng)元所采用的激活函數(shù)。神經(jīng)元之間以不同的連接方式(全連接、部分連接)組合,可以構(gòu)成不同神經(jīng)網(wǎng)絡(luò),對于不同的信號處理效果也不一樣。但是,目前依舊沒有一種通用的方法可以根據(jù)信號輸入的特征來決定神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu),這也是神經(jīng)網(wǎng)絡(luò)模型被稱為黑箱的原因之一,帶來的問題也就是模型的參數(shù)不容易調(diào)整,也不清楚其中到底發(fā)生了什么。因此,在不斷的探索當(dāng)中,前輩大牛們總結(jié)得到了許多經(jīng)典的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu):MLP、BP、FFNN、CNN、RNN等。詳細(xì)的介紹見以下鏈接:

神經(jīng)網(wǎng)絡(luò)優(yōu)點很明顯,給我們提供了構(gòu)建模型的便利,你大可不用顧及模型本身是如何作用的,只需要按照規(guī)則構(gòu)建網(wǎng)絡(luò),然后使用訓(xùn)練數(shù)據(jù)集不斷調(diào)整參數(shù),在許多問題上都能得到一個比較“能接受”的結(jié)果,然而我們對其中發(fā)生了什么是未可知的。在深度學(xué)習(xí)領(lǐng)域,許多問題都可以通過構(gòu)建深層的神經(jīng)網(wǎng)絡(luò)模型來解決。這里,我們不對神經(jīng)網(wǎng)絡(luò)的優(yōu)點做過多闡述。

1.2.2 傳統(tǒng)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的缺陷

從傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)我們可以看出,信號流從輸入層到輸出層依次流過,同一層級的神經(jīng)元之間,信號是不會相互傳遞的。這樣就會導(dǎo)致一個問題,輸出信號只與輸入信號有關(guān),而與輸入信號的先后順序無關(guān)。并且神經(jīng)元本身也不具有存儲信息的能力,整個網(wǎng)絡(luò)也就沒有“記憶”能力,當(dāng)輸入信號是一個跟時間相關(guān)的信號時,如果我們想要通過這段信號的“上下文”信息來理解一段時間序列的意思,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)就顯得無力了。與我們?nèi)祟惖睦斫膺^程類似,我們聽到一句話時往往需要通過這句話中詞語出現(xiàn)的順序以及我們之前所學(xué)的關(guān)于這些詞語的意思來理解整段話的意思,而不是簡單的通過其中的幾個詞語來理解。

例如,在自然語言處理領(lǐng)域,我們要讓神經(jīng)網(wǎng)絡(luò)理解這樣一句話:“地球上最高的山是珠穆朗瑪峰”,按照傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),它可能會將這句話拆分為幾個單獨的詞(地球、上、最高的、山、是、珠穆朗瑪峰),分別輸入到模型之中,而不管這幾個詞之間的順序。然而,直觀上我們可以看到,這幾個詞出現(xiàn)的順序是與最終這句話要表達(dá)的意思是密切相關(guān)的,但傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)無法處理這種情況。

因此,我們需要構(gòu)建具有“記憶”能力的神經(jīng)網(wǎng)絡(luò)模型,用來處理需要理解上下文意思的信號,也就是時間序列數(shù)據(jù)。循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)就是用來處理這類信號的,RNN之所以能夠有效的處理時間序列數(shù)據(jù),主要是基于它比較特殊的運行原理。下面將介紹RNN的構(gòu)建過程和基本運行原理,然后引入長短期記憶網(wǎng)絡(luò)(LSTM)。

2. 循環(huán)神經(jīng)網(wǎng)絡(luò)RNN

本章將介紹循環(huán)神經(jīng)網(wǎng)絡(luò)的基本原理,參考了Understanding LSTM Networks一文,以及YouTue上的一個視頻Illustrated Guide to LSTM's and GRU's: A step by step explanation。

2.1 RNN的構(gòu)造過程

RNN是一種特殊的神經(jīng)網(wǎng)路結(jié)構(gòu),其本身是包含循環(huán)的網(wǎng)絡(luò),允許信息在神經(jīng)元之間傳遞,如下圖所示:

RNN-rolled.png

圖示是一個RNN結(jié)構(gòu)示意圖,圖中的 A 表示神經(jīng)網(wǎng)絡(luò)模型,X_t 表示模型的輸入信號,h_t 表示模型的輸出信號,如果沒有 A 的輸出信號傳遞到 A 的那個箭頭, 這個網(wǎng)絡(luò)模型與普通的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)無異。那么這個箭頭做了什么事情呢?它允許 A 將信息傳遞給 A ,神經(jīng)網(wǎng)絡(luò)將自己的輸出作為輸入了!這怎么理解???作者第一次看到這個圖的時候也是有點懵,讀者可以思考一分鐘。

關(guān)鍵在于輸入信號是一個時間序列,跟時間 t 有關(guān)。也就是說,在 t 時刻,輸入信號 X_t 作為神經(jīng)網(wǎng)絡(luò) A 的輸入,A 的輸出分流為兩部分,一部分輸出給 h_t ,一部分作為一個隱藏的信號流被輸入到 A 中,在下一次時刻輸入信號 X_{t+1} 時,這部分隱藏的信號流也作為輸入信號輸入到了 A 中。此時神經(jīng)網(wǎng)絡(luò) A 就同時接收了 t 時刻和 t+1 時刻的信號輸入了,此時的輸出信號又將被傳遞到下一時刻的 A 中。如果我們把上面那個圖根據(jù)時間 t 展開來看,就是:

RNN-unrolled.png

看到了嗎?t=0 時刻的信息輸出給t=1 時刻的模型 A 了,t=1 時刻的信息輸出給t=2 時刻的模型 A 了, \dots。這樣,相當(dāng)于RNN在時間序列上把自己復(fù)制了很多遍,每個模型都對應(yīng)一個時刻的輸入,并且當(dāng)前時刻的輸出還作為下一時刻的模型的輸入信號。

這樣鏈?zhǔn)降慕Y(jié)構(gòu)揭示了RNN本質(zhì)上是與序列相關(guān)的,是對于時間序列數(shù)據(jù)最自然的神經(jīng)網(wǎng)絡(luò)架構(gòu)。并且理論上,RNN可以保留以前任意時刻的信息。RNN在語音識別、自然語言處理、圖片描述、視頻圖像處理等領(lǐng)域已經(jīng)取得了一定的成果,而且還將更加大放異彩。在實際使用的時候,用得最多的一種RNN結(jié)構(gòu)是LSTM,為什么是LSTM呢?我們從普通RNN的局限性說起。

2.2 RNN的局限性

RNN利用了神經(jīng)網(wǎng)絡(luò)的“內(nèi)部循環(huán)”來保留時間序列的上下文信息,可以使用過去的信號數(shù)據(jù)來推測對當(dāng)前信號的理解,這是非常重要的進步,并且理論上RNN可以保留過去任意時刻的信息。但實際使用RNN時往往遇到問題,請看下面這個例子。

假如我們構(gòu)造了一個語言模型,可以通過當(dāng)前這一句話的意思來預(yù)測下一個詞語?,F(xiàn)在有這樣一句話:“我是一個中國人,出生在普通家庭,我最常說漢語,也喜歡寫漢字。我喜歡媽媽做的菜”。我們的語言模型在預(yù)測“我最常說漢語”的“漢語”這個詞時,它要預(yù)測“我最長說”這后面可能跟的是一個語言,可能是英語,也可能是漢語,那么它需要用到第一句話的“我是中國人”這段話的意思來推測我最常說漢語,而不是英語、法語等。而在預(yù)測“我喜歡媽媽做的菜”的最后的詞“菜”時并不需要“我是中國人”這個信息以及其他的信息,它跟我是不是一個中國人沒有必然的關(guān)系。

這個例子告訴我們,想要精確地處理時間序列,有時候我們只需要用到最近的時刻的信息。例如預(yù)測“我喜歡媽媽做的菜”最后這個詞“菜”,此時信息傳遞是這樣的:


RNN-shorttermdepdencies.png

“菜”這個詞與“我”、“喜歡”、“媽媽”、“做”、“的”這幾個詞關(guān)聯(lián)性比較大,距離也比較近,所以可以直接利用這幾個詞進行最后那個詞語的推測。

而有時候我們又需要用到很早以前時刻的信息,例如預(yù)測“我最常說漢語”最后的這個詞“漢語”。此時信息傳遞是這樣的:


RNN-longtermdependencies.png

此時,我們要預(yù)測“漢語”這個詞,僅僅依靠“我”、“最”、“?!?、“說”這幾個詞還不能得出我說的是漢語,必須要追溯到更早的句子“我是一個中國人”,由“中國人”這個詞語來推測我最常說的是漢語。因此,這種情況下,我們想要推測“漢語”這個詞的時候就比前面那個預(yù)測“菜”這個詞所用到的信息就處于更早的時刻。

而RNN雖然在理論上可以保留所有歷史時刻的信息,但在實際使用時,信息的傳遞往往會因為時間間隔太長而逐漸衰減,傳遞一段時刻以后其信息的作用效果就大大降低了。因此,普通RNN對于信息的長期依賴問題沒有很好的處理辦法。

為了克服這個問題,Hochreiter等人在1997年改進了RNN,提出了一種特殊的RNN模型——LSTM網(wǎng)絡(luò),可以學(xué)習(xí)長期依賴信息,在后面的20多年被改良和得到了廣泛的應(yīng)用,并且取得了極大的成功。

3. 長短時間記憶網(wǎng)絡(luò)(LSTM)

3.1 LSTM與RNN的關(guān)系

長短期記憶(Long Short Term Memory,LSTM)網(wǎng)絡(luò)是一種特殊的RNN模型,其特殊的結(jié)構(gòu)設(shè)計使得它可以避免長期依賴問題,記住很早時刻的信息是LSTM的默認(rèn)行為,而不需要專門為此付出很大代價。

普通的RNN模型中,其重復(fù)神經(jīng)網(wǎng)絡(luò)模塊的鏈?zhǔn)侥P腿缦聢D所示,這個重復(fù)的模塊只有一個非常簡單的結(jié)構(gòu),一個單一的神經(jīng)網(wǎng)絡(luò)層(例如tanh層),這樣就會導(dǎo)致信息的處理能力比較低。


LSTM3-SimpleRNN.png

而LSTM在此基礎(chǔ)上將這個結(jié)構(gòu)改進了,不再是單一的神經(jīng)網(wǎng)絡(luò)層,而是4個,并且以一種特殊的方式進行交互。


LSTM3-chain.png

粗看起來,這個結(jié)構(gòu)有點復(fù)雜,不過不用擔(dān)心,接下來我們會慢慢解釋。在解釋這個神經(jīng)網(wǎng)絡(luò)層時我們先來認(rèn)識一些基本的模塊表示方法。圖中的模塊分為以下幾種:


LSTM2-notation.png
  • 黃色方塊:表示一個神經(jīng)網(wǎng)絡(luò)層(Neural Network Layer);
  • 粉色圓圈:表示按位操作或逐點操作(pointwise operation),例如向量加和、向量乘積等;
  • 單箭頭:表示信號傳遞(向量傳遞);
  • 合流箭頭:表示兩個信號的連接(向量拼接);
  • 分流箭頭:表示信號被復(fù)制后傳遞到2個不同的地方。

下面我們將分別介紹這些模塊如何在LSTM中作用。

3.2 LSTM的基本思想

LSTM的關(guān)鍵是細(xì)胞狀態(tài)(直譯:cell state),表示為 C_t ,用來保存當(dāng)前LSTM的狀態(tài)信息并傳遞到下一時刻的LSTM中,也就是RNN中那根“自循環(huán)”的箭頭。當(dāng)前的LSTM接收來自上一個時刻的細(xì)胞狀態(tài) C_{t-1} ,并與當(dāng)前LSTM接收的信號輸入 x_t 共同作用產(chǎn)生當(dāng)前LSTM的細(xì)胞狀態(tài) C_t,具體的作用方式下面將詳細(xì)介紹。

LSTM3-C-line.png

在LSTM中,采用專門設(shè)計的“門”來引入或者去除細(xì)胞狀態(tài) C_t 中的信息。門是一種讓信息選擇性通過的方法。有的門跟信號處理中的濾波器有點類似,允許信號部分通過或者通過時被門加工了;有的門也跟數(shù)字電路中的邏輯門類似,允許信號通過或者不通過。這里所采用的門包含一個 sigmod 神經(jīng)網(wǎng)絡(luò)層和一個按位的乘法操作,如下圖所示:

LSTM3-gate.png

其中黃色方塊表示sigmod神經(jīng)網(wǎng)絡(luò)層,粉色圓圈表示按位乘法操作。sigmod神經(jīng)網(wǎng)絡(luò)層可以將輸入信號轉(zhuǎn)換為 01 之間的數(shù)值,用來描述有多少量的輸入信號可以通過。0 表示“不允許任何量通過”,1 表示“允許所有量通過”。sigmod神經(jīng)網(wǎng)絡(luò)層起到類似下圖的sigmod函數(shù)所示的作用:

sigmod_function.jpg

其中,橫軸表示輸入信號,縱軸表示經(jīng)過 sigmod 以后的輸出信號。

LSTM主要包括三個不同的門結(jié)構(gòu):遺忘門、記憶門和輸出門。這三個門用來控制LSTM的信息保留和傳遞,最終反映到細(xì)胞狀態(tài) C_t 和輸出信號 h_t 。如下圖所示:

LSTM_gates.png

圖中標(biāo)示了LSTM中各個門的構(gòu)成情況和相互之間的關(guān)系,其中:

  • 遺忘門由一個sigmod神經(jīng)網(wǎng)絡(luò)層和一個按位乘操作構(gòu)成;
  • 記憶門由輸入門(input gate)與tanh神經(jīng)網(wǎng)絡(luò)層和一個按位乘操作構(gòu)成;
  • 輸出門(output gate)與 tanh 函數(shù)(注意:這里不是 tanh 神經(jīng)網(wǎng)絡(luò)層)以及按位乘操作共同作用將細(xì)胞狀態(tài)和輸入信號傳遞到輸出端。

3.3 遺忘門

顧名思義,遺忘門的作用就是用來“忘記”信息的。在LSTM的使用過程中,有一些信息不是必要的,因此遺忘門的作用就是用來選擇這些信息并“忘記”它們。遺忘門決定了細(xì)胞狀態(tài) C_{t-1} 中的哪些信息將被遺忘。那么遺忘門的工作原理是什么呢?看下面這張圖。

LSTM3-focus-f.png

左邊高亮的結(jié)構(gòu)就是遺忘門了,包含一個sigmod神經(jīng)網(wǎng)絡(luò)層(黃色方框,神經(jīng)網(wǎng)絡(luò)參數(shù)為 W_f, b_f),接收 t 時刻的輸入信號 x_tt-1 時刻LSTM的上一個輸出信號 h_{t-1} ,這兩個信號進行拼接以后共同輸入到sigmod神經(jīng)網(wǎng)絡(luò)層中,然后輸出信號 f_t,f_t是一個 01之間的數(shù)值,并與 C_{t-1} 相乘來決定 C_{t-1}中的哪些信息將被保留,哪些信息將被舍棄??赡芸吹竭@里有的初學(xué)者還是不知道具體是什么意思,我們用一個簡單的例子來說明。

假設(shè) C_{t-1} = [0.5, 0.6, 0.4], h_{t-1}=[0.3, 0.8, 0.69], x_t=[0.2,1.3,0.7], 那么遺忘門的輸入信號就是h_{t-1}x_t的組合,即 [h_{t-1}, x_t] = [0.3, 0.8, 0.69, 0.2, 1.3, 0.7], 然后通過sigmod 神經(jīng)網(wǎng)絡(luò)層輸出每一個元素都處于01之間的向量f_t=[0.5, 0.1, 0.8],注意,此時f_t是一個與c_{t-1}維數(shù)相同的向量,此處為3維。如果看到這里還沒有看懂的讀者,可能會有這樣的疑問:輸入信號明明是6維的向量,為什么 f_t就變成了3維呢?這里可能是將sigmod神經(jīng)網(wǎng)絡(luò)層當(dāng)成了sigmod激活函數(shù)了,兩者不是一個東西,初學(xué)者在這里很容易混淆。下文所提及的sigmod神經(jīng)網(wǎng)絡(luò)層和tanh神經(jīng)網(wǎng)絡(luò)層而是類似的道理,他們并不是簡單的sigmod激活函數(shù)和tanh激活函數(shù),在學(xué)習(xí)時要注意區(qū)分。

3.4 記憶門

記憶門的作用與遺忘門相反,它將決定新輸入的信息 x_th_{t-1} 中哪些信息將被保留。

LSTM3-focus-i.png

如圖所示,記憶門包含2個部分。第一個是包含sigmod神經(jīng)網(wǎng)絡(luò)層(輸入門,神經(jīng)網(wǎng)絡(luò)網(wǎng)絡(luò)參數(shù)為 W_i,b_i)和一個 tanh 神經(jīng)網(wǎng)絡(luò)層(神經(jīng)網(wǎng)絡(luò)參數(shù)為 W_c,b_c)。

  • sigmod神經(jīng)網(wǎng)絡(luò)層的作用很明顯,跟遺忘門一樣,它接收 x_th_{t-1} 作為輸入,然后輸出一個 01 之間的數(shù)值 i_t 來決定哪些信息需要被更新;
  • Tanh神經(jīng)網(wǎng)絡(luò)層的作用是將輸入的 x_th_{t-1} 整合,然后通過一個tanh神經(jīng)網(wǎng)絡(luò)層來創(chuàng)建一個新的狀態(tài)候選向量 \widetilde{C}_t ,\widetilde{C}_t 的值范圍在 -11 之間。

記憶門的輸出由上述兩個神經(jīng)網(wǎng)絡(luò)層的輸出決定,i_t\widetilde{C}_t 相乘來選擇哪些信息將被新加入到 t 時刻的細(xì)胞狀態(tài) C_t 中。

3.5 更新細(xì)胞狀態(tài)

有了遺忘門和記憶門,我們就可以更新細(xì)胞狀態(tài) C_t 了。

LSTM3-focus-C.png

這里將遺忘門的輸出 f_t 與上一時刻的細(xì)胞狀態(tài) C_{t-1} 相乘來選擇遺忘和保留一些信息,將記憶門的輸出與從遺忘門選擇后的信息加和得到新的細(xì)胞狀態(tài) C_t。這就表示 t 時刻的細(xì)胞狀態(tài) C_t 已經(jīng)包含了此時需要丟棄的 t-1 時刻傳遞的信息和 t 時刻從輸入信號獲取的需要新加入的信息 i_t \cdot \widetilde{C}_t 。C_t 將繼續(xù)傳遞到 t+1 時刻的LSTM網(wǎng)絡(luò)中,作為新的細(xì)胞狀態(tài)傳遞下去。

3.6 輸出門

前面已經(jīng)講了LSTM如何來更新細(xì)胞狀態(tài) C_t, 那么在 t 時刻我們輸入信號 x_t 以后,對應(yīng)的輸出信號該如何計算呢?

LSTM3-focus-o.png

如上面左圖所示,輸出門就是將t-1時刻傳遞過來并經(jīng)過了前面遺忘門與記憶門選擇后的細(xì)胞狀態(tài)C_{t-1}, 與 t-1 時刻的輸出信號 h_{t-1}t 時刻的輸入信號 x_t 整合到一起作為當(dāng)前時刻的輸出信號。整合的過程如上圖所示,x_th_{t-1} 經(jīng)過一個sigmod神經(jīng)網(wǎng)絡(luò)層(神經(jīng)網(wǎng)絡(luò)參數(shù)為 W_o,b_o)輸出一個 01 之間的數(shù)值 o_t。C_t 經(jīng)過一個tanh函數(shù)(注意:這里不是 tanh 神經(jīng)網(wǎng)絡(luò)層)到一個在 -11 之間的數(shù)值,并與o_t 相乘得到輸出信號 h_t ,同時 h_t 也作為下一個時刻的輸入信號傳遞到下一階段。

其中,tanh函數(shù)是激活函數(shù)的一種,函數(shù)圖像為:

tanh.png

至此,基本的LSTM網(wǎng)絡(luò)模型就介紹完了。如果對LSTM模型還沒有理解到的,可以看一下這個視頻,作者是一個外國小哥,英文講解的,有動圖,方便理解。

3.7 LSTM的一些變體

前面已經(jīng)介紹了基本的LSTM網(wǎng)絡(luò)模型,而實際應(yīng)用時,我們常常會采用LSTM的一些變體,雖然差異不大,這里不再做詳細(xì)介紹,有興趣的讀者可以自行了解。

3.7.1 在門上增加窺視孔

LSTM3-var-peepholes.png

這是2000年Gers和Schemidhuber教授提出的一種LSTM變體。圖中,在傳統(tǒng)的LSTM結(jié)構(gòu)基礎(chǔ)上,每個門(遺忘門、記憶門和輸出門)增加了一個“窺視孔”(Peephole),有的學(xué)者在使用時也選擇只對部分門加入窺視孔。

3.7.2 整合遺忘門和輸入門

LSTM3-var-tied.png

與傳統(tǒng)的LSTM不同的是,這個變體不需要分開來確定要被遺忘和記住的信息,采用一個結(jié)構(gòu)搞定。在遺忘門的輸出信號值(01之間)上,用 1 減去該數(shù)值來作為記憶門的狀態(tài)選擇,表示只更新需要被遺忘的那些信息的狀態(tài)。

3.7.3 GRU

改進比較大的一個LSTM變體叫Gated Recurrent Unit (GRU),目前應(yīng)用較多。結(jié)構(gòu)圖如下


LSTM3-var-GRU.png

GRU主要包含2個門:重置門和更新門。GRU混合了細(xì)胞狀態(tài) C_t 和隱藏狀態(tài) h_{t-1} 為一個新的狀態(tài),使用 h_t 來表示。 該模型比傳統(tǒng)的標(biāo)準(zhǔn)LSTM模型簡單。

4. 基于Pytorch的LSTM代碼實現(xiàn)

Pytorch是Python的一個機器學(xué)習(xí)包,與Tensorflow類似,Pytorch非常適合用來構(gòu)建神經(jīng)網(wǎng)絡(luò)模型,并且已經(jīng)提供了一些常用的神經(jīng)網(wǎng)絡(luò)模型包,用戶可以直接調(diào)用。下面我們就用一個簡單的小例子來說明如何使用Pytorch來構(gòu)建LSTM模型。

我們使用正弦函數(shù)和余弦函數(shù)來構(gòu)造時間序列,而正余弦函數(shù)之間是成導(dǎo)數(shù)關(guān)系,所以我們可以構(gòu)造模型來學(xué)習(xí)正弦函數(shù)與余弦函數(shù)之間的映射關(guān)系,通過輸入正弦函數(shù)的值來預(yù)測對應(yīng)的余弦函數(shù)的值。

正弦函數(shù)和余弦函數(shù)對應(yīng)關(guān)系圖如下圖所示:


demo_sine_cosine.png

可以看到,每一個函數(shù)曲線上,每一個正弦函數(shù)的值都對應(yīng)一個余弦函數(shù)值。但其實如果只關(guān)心正弦函數(shù)的值本身而不考慮當(dāng)前值所在的時間,那么正弦函數(shù)值和余弦函數(shù)值不是一一對應(yīng)關(guān)系。例如,當(dāng) t=2.5t=6.8 時,sin(t)=0.5 ,但在這兩個不同的時刻,cos(t) 的值卻不一樣,也就是說如果不考慮時間,同一個正弦函數(shù)值可能對應(yīng)了不同的幾個余弦函數(shù)值。對于傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)來說,它僅僅基于當(dāng)前的輸入來預(yù)測輸出,對于這種同一個輸入可能對應(yīng)多個輸出的情況不再適用。

我們?nèi)≌液瘮?shù)的值作為LSTM的輸入,來預(yù)測余弦函數(shù)的值?;赑ytorch來構(gòu)建LSTM模型,采用1個輸入神經(jīng)元,1個輸出神經(jīng)元,16個隱藏神經(jīng)元作為LSTM網(wǎng)絡(luò)的構(gòu)成參數(shù),平均絕對誤差(LMSE)作為損失誤差,使用Adam優(yōu)化算法來訓(xùn)練LSTM神經(jīng)網(wǎng)絡(luò)?;贏naconda和Python3.6的完整代碼如下:

# -*- coding:UTF-8 -*-
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt

# Define LSTM Neural Networks
class LstmRNN(nn.Module):
    """
        Parameters:
        - input_size: feature size
        - hidden_size: number of hidden units
        - output_size: number of output
        - num_layers: layers of LSTM to stack
    """
    def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
        super().__init__()
 
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers) # utilize the LSTM model in torch.nn 
        self.forwardCalculation = nn.Linear(hidden_size, output_size)
 
    def forward(self, _x):
        x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)
        s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)
        x = x.view(s*b, h)
        x = self.forwardCalculation(x)
        x = x.view(s, b, -1)
        return x

if __name__ == '__main__':
    # create database
    data_len = 200
    t = np.linspace(0, 12*np.pi, data_len)
    sin_t = np.sin(t)
    cos_t = np.cos(t)

    dataset = np.zeros((data_len, 2))
    dataset[:,0] = sin_t
    dataset[:,1] = cos_t
    dataset = dataset.astype('float32')

    # plot part of the original dataset
    plt.figure()
    plt.plot(t[0:60], dataset[0:60,0], label='sin(t)')
    plt.plot(t[0:60], dataset[0:60,1], label = 'cos(t)')
    plt.plot([2.5, 2.5], [-1.3, 0.55], 'r--', label='t = 2.5') # t = 2.5
    plt.plot([6.8, 6.8], [-1.3, 0.85], 'm--', label='t = 6.8') # t = 6.8
    plt.xlabel('t')
    plt.ylim(-1.2, 1.2)
    plt.ylabel('sin(t) and cos(t)')
    plt.legend(loc='upper right')

    # choose dataset for training and testing
    train_data_ratio = 0.5 # Choose 80% of the data for testing
    train_data_len = int(data_len*train_data_ratio)
    train_x = dataset[:train_data_len, 0]
    train_y = dataset[:train_data_len, 1]
    INPUT_FEATURES_NUM = 1
    OUTPUT_FEATURES_NUM = 1
    t_for_training = t[:train_data_len]

    # test_x = train_x
    # test_y = train_y
    test_x = dataset[train_data_len:, 0]
    test_y = dataset[train_data_len:, 1]
    t_for_testing = t[train_data_len:]

    # ----------------- train -------------------
    train_x_tensor = train_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5
    train_y_tensor = train_y.reshape(-1, 5, OUTPUT_FEATURES_NUM) # set batch size to 5
 
    # transfer data to pytorch tensor
    train_x_tensor = torch.from_numpy(train_x_tensor)
    train_y_tensor = torch.from_numpy(train_y_tensor)
    # test_x_tensor = torch.from_numpy(test_x)
 
    lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=1) # 16 hidden units
    print('LSTM model:', lstm_model)
    print('model.parameters:', lstm_model.parameters)
 
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-2)
 
    max_epochs = 10000
    for epoch in range(max_epochs):
        output = lstm_model(train_x_tensor)
        loss = loss_function(output, train_y_tensor)
 
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
 
        if loss.item() < 1e-4:
            print('Epoch [{}/{}], Loss: {:.5f}'.format(epoch+1, max_epochs, loss.item()))
            print("The loss value is reached")
            break
        elif (epoch+1) % 100 == 0:
            print('Epoch: [{}/{}], Loss:{:.5f}'.format(epoch+1, max_epochs, loss.item()))
 
    # prediction on training dataset
    predictive_y_for_training = lstm_model(train_x_tensor)
    predictive_y_for_training = predictive_y_for_training.view(-1, OUTPUT_FEATURES_NUM).data.numpy()

    # torch.save(lstm_model.state_dict(), 'model_params.pkl') # save model parameters to files
 
    # ----------------- test -------------------
    # lstm_model.load_state_dict(torch.load('model_params.pkl'))  # load model parameters from files
    lstm_model = lstm_model.eval() # switch to testing model

    # prediction on test dataset
    test_x_tensor = test_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5, the same value with the training set
    test_x_tensor = torch.from_numpy(test_x_tensor)
 
    predictive_y_for_testing = lstm_model(test_x_tensor)
    predictive_y_for_testing = predictive_y_for_testing.view(-1, OUTPUT_FEATURES_NUM).data.numpy()
 
    # ----------------- plot -------------------
    plt.figure()
    plt.plot(t_for_training, train_x, 'g', label='sin_trn')
    plt.plot(t_for_training, train_y, 'b', label='ref_cos_trn')
    plt.plot(t_for_training, predictive_y_for_training, 'y--', label='pre_cos_trn')

    plt.plot(t_for_testing, test_x, 'c', label='sin_tst')
    plt.plot(t_for_testing, test_y, 'k', label='ref_cos_tst')
    plt.plot(t_for_testing, predictive_y_for_testing, 'm--', label='pre_cos_tst')

    plt.plot([t[train_data_len], t[train_data_len]], [-1.2, 4.0], 'r--', label='separation line') # separation line

    plt.xlabel('t')
    plt.ylabel('sin(t) and cos(t)')
    plt.xlim(t[0], t[-1])
    plt.ylim(-1.2, 4)
    plt.legend(loc='upper right')
    plt.text(14, 2, "train", size = 15, alpha = 1.0)
    plt.text(20, 2, "test", size = 15, alpha = 1.0)

    plt.show()

訓(xùn)練的過程如下:


LSTM_training.png

該模型在訓(xùn)練集和測試集上的結(jié)果如下:


demo_LSTM.png

圖中,紅色虛線的左邊表示該模型在訓(xùn)練數(shù)據(jù)集上的表現(xiàn),右邊表示該模型在測試數(shù)據(jù)集上的表現(xiàn)??梢钥吹剑褂肔STM構(gòu)建訓(xùn)練模型,我們可以僅僅使用正弦函數(shù)在 t 時刻的值作為輸入來準(zhǔn)確預(yù)測 t 時刻的余弦函數(shù)值,不用額外添加當(dāng)前的時間信息、速度信息等。

5. 參考鏈接

  1. Undterstanding LSTM Networks
  2. 推薦看這個Youtube視頻
  3. 循環(huán)神經(jīng)網(wǎng)絡(luò)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 9. 循環(huán)神經(jīng)網(wǎng)絡(luò) 場景描述 循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)是一種主流的深度學(xué)習(xí)...
    _龍雀閱讀 2,979評論 0 3
  • 作者 |Edwin Chen 編譯 | AI100 第一次接觸長短期記憶神經(jīng)網(wǎng)絡(luò)(LSTM)時,我驚呆了。 原來,...
    王逸飛_1cd9閱讀 3,412評論 0 19
  • 最近工作的時候,發(fā)現(xiàn)自己的一個狀態(tài):腦子有點空空的??赡苁亲罱嘤?xùn)有點多,吸收不過來吧!新知識間的聯(lián)系還不密切,用...
    賢菜花閱讀 195評論 1 0
  • 生活中,每個人都有自己在意的人或事,記住美好,忘記不好,人就會活得輕松快樂。春享花香,夏披霞光,秋觀圓月,冬...
    半杯水_31a7閱讀 217評論 1 1
  • 如果古代有舞蹈評級,她的舞技應(yīng)該問鼎最高級別了,史傳她可掌上舞。 她雖出身不高,卻靠跳舞贏得皇帝盛寵不衰,進而封后...
    葉紫檀閱讀 3,670評論 12 38

友情鏈接更多精彩內(nèi)容