Informer:超越Transformer的長序列預(yù)測模型

逍遙一身作品

背景介紹

序列預(yù)測在現(xiàn)實(shí)中比較基礎(chǔ)但又十分重要的研究場景,如商品銷量及庫存的預(yù)測,未來某個(gè)時(shí)刻某只股票的價(jià)格,疾病的傳播及擴(kuò)散預(yù)測等。而長序列預(yù)測(Long Sequence Time-Series Forecasting,以下稱為 LSTF)方面的工作因?yàn)槠錃v史數(shù)據(jù)量大、計(jì)算復(fù)雜性高、預(yù)測精度要求高更是成為序列預(yù)測研究的熱點(diǎn),但一直以來長序列預(yù)測并沒有取得太好的效果。在2021年的AAAI 最佳論文獎(jiǎng)項(xiàng)中就有一篇來自北京航空航天大學(xué)的工作:Informer,其主要的工作是使用 Transfomer 實(shí)現(xiàn)長序列預(yù)測,實(shí)驗(yàn)表明取得了不錯(cuò)的結(jié)果,本文就給大家?guī)韺?duì)這篇工作的解讀。

論文標(biāo)題:

Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting

論文鏈接:

https://arxiv.org/abs/2012.07436

源碼鏈接:

https://github.com/zhouhaoyi/ETDataset

2. 相關(guān)工作

長序列時(shí)間序列預(yù)測(Long sequence time-series forecasting,LSTF)要求模型具有較高的預(yù)測能力,即能夠準(zhǔn)確地捕捉輸出與輸入之間的長期依賴關(guān)系。近年來,針對(duì)序列預(yù)測問題的研究主要都集中在短序列的預(yù)測上,輸入序列越長,傳統(tǒng)模型的計(jì)算復(fù)雜度越高,同時(shí)預(yù)測能力顯著降低。文中采用LSTM網(wǎng)絡(luò)對(duì)一組電子變壓器數(shù)據(jù)進(jìn)行了短時(shí)序列(12個(gè)數(shù)據(jù)點(diǎn),0.5天數(shù)據(jù))和長時(shí)序列(480個(gè)數(shù)據(jù)點(diǎn),20天數(shù)據(jù))分別進(jìn)行了實(shí)驗(yàn),結(jié)果發(fā)現(xiàn),隨著序列長度的增加,預(yù)測誤差顯著增加,同時(shí)預(yù)測速度急劇下降。

(LSTM預(yù)測誤差和效率對(duì)比)

近年來的研究表明,相對(duì)RNN類型的模型來看,Transformer在長距離依賴的表達(dá)方面表現(xiàn)出了較高的潛力。然而,Transformer存在幾個(gè)嚴(yán)重的問題,使其不能直接適用于LSTF問題,例、高內(nèi)存使用量編碼器-解碼器體系結(jié)構(gòu)固有的局限性。其中Transformer模型主要存在下面三個(gè)問題

self-attention機(jī)制的二次計(jì)算復(fù)雜度問題:self-attention機(jī)制的點(diǎn)積操作使每層的時(shí)間復(fù)雜度和內(nèi)存使用量為 。

高內(nèi)存使用量問題:對(duì)長序列輸入進(jìn)行堆疊時(shí),J個(gè)encoder-decoder層的堆棧使總內(nèi)存使用量為 ,這限制了模型在接收長序列輸入時(shí)的可伸縮性。

預(yù)測長期輸出的效率問題:Transformer的動(dòng)態(tài)解碼過程,使得輸出變?yōu)橐粋€(gè)接一個(gè)的輸出,后面的輸出依賴前面一個(gè)時(shí)間步的預(yù)測結(jié)果,這樣會(huì)導(dǎo)致推理非常慢。

基于Transformer的改進(jìn)工作,學(xué)術(shù)界也進(jìn)行得如火如荼,如The Sparse Transformer (Child et al. 2019),

LogSparse Transformer (Li et al. 2019), and Longformer

(Beltagy, Peters, and Cohan 2020) 均使用一種稀疏化方法來處理Transformer的時(shí)間復(fù)雜度問題,將時(shí)間復(fù)雜度降到了 O(Llog L), 但這些模型均采用一種啟發(fā)式方法來決定選擇哪些attention來進(jìn)行計(jì)算,不具有通用性。Reformer (Kitaev, Kaiser,

and Levskaya 2019) 使用本地哈?;椒ㄓ?jì)算attention,同樣取得O(Llog L)的時(shí)間復(fù)雜度,但它只能處理特別長的序列 。最近, Linformer (Wang

et al. 2020) 宣稱可以達(dá)到O(L)的時(shí)間復(fù)雜度,但其矩陣維度在真實(shí)場景的長序列預(yù)測中無法完全確定,有退化到O(L2)的風(fēng)險(xiǎn)。其它一些模型壓縮方法也都僅僅解決attention計(jì)算時(shí)的二次時(shí)間復(fù)雜度的問題,針對(duì)長序列的輸入和輸出導(dǎo)致的內(nèi)存使用量偏高和輸出效率低的問題都沒有解決。

模型特點(diǎn)

這篇文章中,作者針對(duì)Transformer模型,提出了下面的目標(biāo):能否改進(jìn)Transformer模型,使其計(jì)算、內(nèi)存和體系結(jié)構(gòu)更高效,同時(shí)保持更高的預(yù)測能力?為了達(dá)到這些目標(biāo),這篇文章中設(shè)計(jì)了一種基于改進(jìn)Transformer的LSTF模型,即Informer模型,該模型具有三個(gè)顯著特征:

一種ProbSpare self-attention機(jī)制,它可以在時(shí)間復(fù)雜度和內(nèi)存使用方面達(dá)到 。

self-attention蒸餾機(jī)制,通過對(duì)每個(gè)attention層結(jié)果上套一個(gè)Conv1D,再加一個(gè)Maxpooling層,來對(duì)每層輸出減半來突出主導(dǎo)注意,并有效地處理過長的輸入序列。

并行生成式解碼器機(jī)制,對(duì)長時(shí)間序列進(jìn)行一次前向計(jì)算輸出所有預(yù)測結(jié)果而不是逐步的方式進(jìn)行預(yù)測,這大大提高了長序列預(yù)測的推理速度。

在介紹模型之前,本文先對(duì)模型的輸入輸出、編解碼器結(jié)構(gòu)和輸入表示進(jìn)行介紹(熟悉編解碼器的同學(xué)可以略過)。

3.預(yù)備知識(shí)

01 輸入輸出形式化表示

輸入:

輸出:

其中dy>1

02 編解碼結(jié)構(gòu)

通常,編碼-解碼結(jié)構(gòu)是這樣設(shè)計(jì)的:將輸入Xt編碼為隱層狀態(tài)Ht,然后將隱層狀態(tài)解碼為輸出Yt。推理階段則采用逐步推理方式,即動(dòng)態(tài)解碼,一個(gè)時(shí)間步解碼完成后再解碼另外一個(gè)時(shí)間步的結(jié)果。

具體為:通過前一步隱層狀態(tài)

和其它前一步輸出,計(jì)算k+1步的隱層狀態(tài)

,然后預(yù)測第k+1步的輸出

03 輸入表示

在長序列時(shí)序建模問題中,不僅需要局部時(shí)序信息還需要層次時(shí)序信息,如星期、月和年等,以及突發(fā)時(shí)間戳信息(事件或某些節(jié)假日等)。常規(guī)自注意力機(jī)制很難直接適配,可能會(huì)帶來編碼器和解碼器之間的 query 和 key 的錯(cuò)誤匹配問題,最終影響預(yù)測效果。

如上圖所示,位置嵌入分為了三種:

局部時(shí)間戳:即Transformer中的位置編碼。計(jì)算公式為:

2. 全局時(shí)間戳。這里使用的可學(xué)習(xí)嵌入表示

具體實(shí)現(xiàn)時(shí),構(gòu)建一個(gè)詞匯表,使用Embedding層表示每一個(gè)“詞匯”,為對(duì)齊維度,使用1D卷積將輸入標(biāo)量

轉(zhuǎn)為向量

計(jì)算公式為:

4. 網(wǎng)絡(luò)結(jié)構(gòu)

文中提出的Informer模型整體上仍然保存了Encoder-Decoder的架構(gòu),其整體框架如下圖所示:

在上圖左邊為編碼器,其編碼過程為:編碼器接收長序列輸入(綠色部分),通過ProbSparse自注意力模塊和自注意力蒸餾模塊,得到特征表示。右邊為解碼器,解碼過程為:解碼器接收長序列輸入(預(yù)測目標(biāo)部分設(shè)置為0),通過多頭注意力與編碼特征進(jìn)行交互,最后一次性直接預(yù)測輸出目標(biāo)部分(橙黃色部分)。

01 高效的自注意力機(jī)制

首先,讓我們回顧一下經(jīng)典的自注意力機(jī)制:根據(jù)三個(gè)輸入(query, key, value),使用縮放點(diǎn)積計(jì)算輸入的注意力矩陣:

其中,

第i個(gè)Query的attention系數(shù)的概率形式是:

該文作者通過實(shí)驗(yàn)驗(yàn)證了經(jīng)典自注意力機(jī)制存在稀疏性,即自注意力特征圖的長尾分布現(xiàn)象。具體為如下圖所示,較少的點(diǎn)積貢獻(xiàn)了絕大部分的注意力得分,也就是說其余部分的成對(duì)點(diǎn)積可以忽略。

為了度量query的稀疏性,作者使用KL散度來計(jì)算query的attention概率分布與均勻分布的概率分布的相對(duì)熵,其中第i個(gè)query的稀疏性的評(píng)價(jià)公式是:

其中第一項(xiàng)是 對(duì)于所有的key的Log-Sum-Exp (LSE),第二項(xiàng)是它們的算數(shù)平均值。

按照上面的評(píng)價(jià)方式,就可以得到ProbSparse self-attetion的公式,即:

其中, 是和 具有相同尺寸的稀疏矩陣,并且它只包含在稀疏評(píng)估 下top-u的queries。其中,u的大小通過一個(gè)采樣參數(shù)來決定。這使得ProbSparse self-attention對(duì)于每個(gè)query-key只需要計(jì)算 點(diǎn)積操作。另外,為了消除Log-Sum-Exp (LSE)在計(jì)算過程中由于計(jì)算精度截?cái)鄬?dǎo)致的數(shù)值穩(wěn)定性問題,其對(duì)稀疏評(píng)估進(jìn)行了上邊界的計(jì)算,從而保證了計(jì)算的時(shí)間和空間復(fù)雜度為 。

通過該方法解決了Transformer的第一個(gè)問題,即:self-attention機(jī)制的二次計(jì)算復(fù)雜度問題。

02 編碼器

編碼器的設(shè)計(jì)目標(biāo)為:在內(nèi)存有限的情況下,通過注意力蒸餾機(jī)制,使得單個(gè)層級(jí)特征在時(shí)間維度減半,從而允許編碼器處理更長的序列輸入。

編碼器的主要功能是捕獲長序列輸入之間的長范圍依賴。

作為ProbSpare自注意機(jī)制的結(jié)果,encoder的特征映射存在值V的冗余組合,因此,這里利用distilling操作對(duì)具有主導(dǎo)注意力的優(yōu)勢特征賦予更高權(quán)重,并在下一層生成focus self-attention特征映射。從j到j(luò)+1層的distilling操作的過程如下:

其中, 包含了multi-head probsparse self-attention以及在attention block中的關(guān)鍵操作。Conv1d表示時(shí)間序列上的一維卷積操作,并通過ELU作為了激活函數(shù),最后再進(jìn)行最大池化操作。

同時(shí)為了增強(qiáng)注意力蒸餾機(jī)制的魯棒性,作者還構(gòu)建了主序列的多個(gè)長度減半副本,每個(gè)副本是前一個(gè)副本長度的一半,這些副本序列和主序列經(jīng)過同樣的注意力蒸餾機(jī)制,從而構(gòu)建多個(gè)長度為L/4的特征圖。

最后將這些特征圖拼接起來,構(gòu)成輸入編碼器最終的特征圖。

通過以上方法,可以逐級(jí)減少特征圖的大小,在計(jì)算空間上不會(huì)耗費(fèi)過多的內(nèi)存,解決了Transformer的第二個(gè)問題:高內(nèi)存使用量問題

03 解碼器

解碼器的設(shè)計(jì)目標(biāo)為:通過一次前向計(jì)算,預(yù)測長序列的所有輸出。

原始的解碼器結(jié)構(gòu),即堆疊兩個(gè)相同的多頭注意力層,是一個(gè)動(dòng)態(tài)解碼過程,即一步一步進(jìn)行迭代輸出。不同的是,本文采用的是批量生成式預(yù)測直接輸出多步預(yù)測結(jié)果。此外,喂進(jìn)解碼器的輸入也有所不同:

。其中

表示占位符(預(yù)測值);

表示開始“字符”,這個(gè)設(shè)置就比較簡單有效。該方法源于NLP技術(shù)中的動(dòng)態(tài)解碼,原本由單個(gè)字符作為“開始字符”,作者將其擴(kuò)展為生成式方式,即從靠近預(yù)測目標(biāo)的輸入序列中動(dòng)態(tài)采樣部分序列作為“開始Token”。于是,通過多個(gè)output一起計(jì)算輸出的方法,解決了Transformer在LSTF中的第三個(gè)問題:預(yù)測長期輸出的效率問題。

5.實(shí)驗(yàn)結(jié)果

01 數(shù)據(jù)集

該工作使用兩個(gè)真實(shí)數(shù)據(jù)集和兩個(gè)公開評(píng)測數(shù)據(jù)集。

1. ETT(電力變壓器溫度)數(shù)據(jù)集,從中國兩個(gè)不同的城市收集的兩年數(shù)據(jù),傳感器每15分鐘采集一個(gè)樣本點(diǎn),按照兩種采樣粒度 {1 小時(shí),15 分鐘}將原始數(shù)據(jù)變?yōu)閮蓚€(gè)數(shù)據(jù)集;

2. ECL(耗電量),收集了 321 個(gè)客戶端的用電消耗,共22個(gè)月的數(shù)據(jù);

3. 天氣數(shù)據(jù)集,美國 1600 各地區(qū) 4 年的氣象數(shù)據(jù),采樣粒度為1小時(shí)。

02 實(shí)驗(yàn)結(jié)果分析

實(shí)驗(yàn)結(jié)果分析是我們需要重點(diǎn)關(guān)注的部分。

無論是單變量的長序列預(yù)測還是多變量的長序列預(yù)測,Informer 均能在多數(shù)數(shù)據(jù)集上取得優(yōu)于baseline的表現(xiàn)(評(píng)估指標(biāo)MSE和MAE)。從單變量時(shí)序預(yù)測來看:1. 在兩種評(píng)測指標(biāo)上,Informer 能夠取得不錯(cuò)的提升,并且隨著預(yù)測序列長度的增加,推理速度和預(yù)測誤差增長地較為緩慢;2. Informer 相較原始自注意力的 infomer,取得最優(yōu)的次數(shù)最多(28>14),并且優(yōu)于 LogTansformer 和 Reformer,說明query的稀疏性假設(shè)在很多數(shù)據(jù)集上是成立的;3. Informer 相較 LSTMa 有巨大的提升,說明相較 RNN 類模型,自注意力機(jī)制中較短網(wǎng)絡(luò)路徑模型具有更強(qiáng)的預(yù)測能力。從多變量預(yù)測來看,Informer 依然擊敗了baseline,可以得出與單變量時(shí)序預(yù)測相同的結(jié)論。并且可以很方便地將單變量預(yù)測變?yōu)槎嘧兞款A(yù)測,只需要調(diào)整最后的全連接層。

03 參數(shù)敏感度分析

作者這里通過不同參數(shù)的調(diào)整對(duì)模型最終性能表現(xiàn)進(jìn)行了評(píng)測。參數(shù)主要包括輸入長度、ProbSparse 的采樣因子和堆疊的網(wǎng)絡(luò)層數(shù)。在預(yù)測長序列時(shí),通過增加序列輸入長度,取得了更好的預(yù)測結(jié)果,這是因?yàn)楦L的序列,蘊(yùn)含了更豐富的特征。

04 消融實(shí)驗(yàn)

消融實(shí)驗(yàn)主要針對(duì)本文的三個(gè)創(chuàng)新點(diǎn)進(jìn)行測試:ProbSparse 自注意力機(jī)制、自注意力蒸餾、生成式解碼器。

實(shí)驗(yàn)表明,三種創(chuàng)新機(jī)制對(duì)較長的序列預(yù)測任務(wù)均有正面影響。

6. 總結(jié)

回顧一下,本文主要突出以下幾點(diǎn):

a. 長序列預(yù)測問題是一個(gè)基礎(chǔ)且現(xiàn)實(shí)需求高的問題,值得重點(diǎn)研究。

b. 基于自注意力機(jī)制的模型在解決序列預(yù)測問題上是一個(gè)很有吸引力的模型

?- ProbSparse 自注意力機(jī)制有助于降低attention計(jì)算過程中的時(shí)間復(fù)雜度

- self-attention蒸餾機(jī)制可以有效減少特征圖的時(shí)間維度,從而降低內(nèi)存消耗量

- 并行生成式解碼器機(jī)制能夠大幅提高長序列預(yù)測的效率

c. Query的注意力稀疏假設(shè)在自注意力網(wǎng)絡(luò)結(jié)構(gòu)中是行之有效的

本文提出的方法可以在其他領(lǐng)域取得同樣良好的效果,如:長序列文本、音樂、視頻、圖像的生成任務(wù)。

Informer能獲得AAAI 2021的Best Paper確實(shí)有很多值得肯定的地方。首先,在閱讀體驗(yàn)上,我們可以很好地順著作者的邏輯結(jié)構(gòu)了解到本文的研究動(dòng)機(jī)、方法及實(shí)驗(yàn)結(jié)果,由此可見,不論什么領(lǐng)域,講故事的能力絕對(duì)重要。此外,實(shí)驗(yàn)部分比較充實(shí),能夠很好地圍繞全文一直提及的長序列預(yù)測的難點(diǎn)以及Transformer應(yīng)用到長序列預(yù)測上需要解決的問題。最后,作者針對(duì)傳統(tǒng)attention網(wǎng)絡(luò)結(jié)構(gòu)中的query稀疏性進(jìn)行了驗(yàn)證,為今后attention研究提供了明確的思路與方向,值得學(xué)習(xí)!


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容