一、概述
大語言模型(LLMs)在預(yù)訓(xùn)練的過程中通常會捕捉數(shù)據(jù)的特征,而這些訓(xùn)練數(shù)據(jù)通常既包含高質(zhì)量的也包含低質(zhì)量的,因此模型有時會產(chǎn)生不被期望的行為,如編造事實,生成有偏見或有毒的文本,甚至對人類有害的內(nèi)容。因此,將LLMs與人類價值觀(如helpful, honest, 和harmless, 即3H)對齊是非常重要的,目前采用的主流的技術(shù)即是基于人類反饋的強化學(xué)習(xí)技術(shù)(RLHF)。
通常來說,RLHF包括三個步驟:
①supervised fine-tuning (SFT):對LLMs進行微調(diào),LLMs通過模仿人類標注的對話示例來學(xué)習(xí)通用的的類似人類的對話。
②reward model (RM) training:對于模型對同一個prompt的多個回復(fù),利用人類標注來進行排序以獲取人類偏好,然后單獨使用另一個語言模型作為reward model,在這個reward model上使用標注的數(shù)據(jù)進行訓(xùn)練(類似排序任務(wù))。
③proximal policy optimization (PPO):以訓(xùn)練得到的reward model作為reward function來繼續(xù)訓(xùn)練優(yōu)化LLMs,促進其與人類偏好的對齊。

RLHF的整個過程如上圖所示,也可參考OpenAI的InstructGPT的做法:InstructGPT:語言模型的人類反饋指令對齊。InstructGPT使用的數(shù)據(jù)規(guī)模如下表所示:

本文會介紹基于PPO(proximal policy optimization)的RLHF的技術(shù)細節(jié),包括Reward Modeling、Policy Gradient Methods、Advantage Actor-Critic (A2C)等。本文主要介紹A2C框架,其中包括四個模型的調(diào)度:
①Policy Model/Actor Model:由SFT之后的模型初始化而來。作為策略(policy)模型,用于接收上文,做出動作,預(yù)測下一個字符。學(xué)習(xí)完畢之后,我們最終使用的就是這個模型。
②Reference Model:和Actor Model同樣初始化自SFT Model,訓(xùn)練過程中凍結(jié)參數(shù),用于和Actor Model做對比,保證模型不要偏離原始SFT Model太多。
③Reward Model:作為環(huán)境(env),訓(xùn)練過程中凍結(jié)參數(shù),針對每一個狀態(tài),給出獎勵分數(shù)。
④Critic Model:由Reward Model初始化而來,用于近似價值函數(shù),輸入為當前的狀態(tài),估計當前狀態(tài)的價值。
二、Reward Modeling
Reward model可以使用移除了最后一個unembedding層的預(yù)訓(xùn)練語言模型來作為基礎(chǔ)架構(gòu),通常就是將最后一個token最終的embedding輸入給一個線性層,然后得到一個標量值,即是reward的值。在InstructGPT中,作者嘗試用了1.3B、6B和175B的GPT-3來做實驗,最終綜合考慮只用6B的模型來訓(xùn)練reward model。在訓(xùn)練reward model時,對于同一個輸入prompt,有一個更偏好的輸出
和一個相對不偏好的輸出
。每一對偏好和不偏好的回復(fù)的損失為:
這里的是sigmoid函數(shù),
代表reward model,其參數(shù)為
,
是一個reward model為
和
預(yù)測的標量得分。另外,也可以額外加上一個模仿學(xué)習(xí)(imitation learning)的損失,也就是一個語言模型的預(yù)訓(xùn)練損失,來讓模型模仿句子對中更偏好的那一個:
這里的和
都是超參數(shù),
是訓(xùn)練集的經(jīng)驗分布,
與
除了頂部線性層不同以外是同一個模型(
線性層的維度為詞典的大小),
是給定prompt
和偏好回復(fù)
后的似然。
在PPO階段,使用訓(xùn)練得到的reward model來作為reward function訓(xùn)練policy模型時,還可以為reward function添加一個基于當前policy模型和SFT模型
之間KL散度的懲罰項,此時reward function為:
這里的是一個超參數(shù)。這個KL散度的懲罰項,通常來說有兩個作用:
①作為一個entropy bonus,促進模型在policy空間中的探索,防止policy model過早地收斂到一個單一的模式。
②可以確保強化學(xué)習(xí)policy model的輸出不會與reward model在訓(xùn)練階段遇到的樣本嚴重偏離。
三、Reinforcement Learning
將強化學(xué)習(xí)應(yīng)用于對話生成是一種艱難的挑戰(zhàn),這是因為其狀態(tài)-動作空間(state-action space)是非常巨大和廣闊的。在這樣的場景中,我們將人類的互動作為環(huán)境(environment)。在每個時間步,agent(也就是LLMs)將從環(huán)境(即對話歷史)中接收一個狀態(tài)
,其中包括所有的對話歷史文本(也就是prompt和已經(jīng)生成的回復(fù))。接著,基于policy
,agent的動作
為生成下一個token。環(huán)境相應(yīng)的也會反饋一個reward
,這個reward來自于一個reward function
,
是從人類偏好數(shù)據(jù)中訓(xùn)練得到的,也就是前文的reward model。此后,agent將轉(zhuǎn)換到下一個狀態(tài)
。整個過程將得到一個軌跡(trajectory)
。對于LLMs的一個輸入
和輸出
來說,
,
,可采取的動作的所有選擇即是模型詞典中的所有token。強化學(xué)習(xí)的目的即是最大化一個軌跡的累積reward(也就是回報,return)。一種有限期無折扣回報( finite-horizon undiscounted return)為
,即有限步數(shù)的累積reward的加和。另一種無限期折扣回報( infinite-horizon discounted return)為
,這種計算方式考慮了整個軌跡上所獲得的的所有回報,其中
為折扣率。
- Policy Gradient Methods
在介紹最常用的A2C方法之前,先介紹下更基礎(chǔ)一些的policy gradient方法。Policy gradient方法是一種強化學(xué)習(xí)技術(shù),其直接優(yōu)化agent的policy,也就是從state到action的映射。Policy gradient方法的核心思想是使用梯度上升方法直接優(yōu)化policy。從本質(zhì)上講,這類方法調(diào)整policy model的參數(shù),使其朝著最大限度地提高預(yù)期return的方向優(yōu)化。Policy通常由
參數(shù)化,我們將其表示為
,即在狀態(tài)
下采取動作
的概率。Policy gradient的參數(shù)更新方式為:
這里的是學(xué)習(xí)率,
表示當采用policy
時的期望return,其梯度
被稱為policy gradient。一個policy gradient的通用形式為:
這里的可以是
,
,
(
是一個baseline)中的任意一個。這些不同的形式會得到policy gradient相同的期望值,但會得到不同的方差。
Return通過蒙特卡洛采樣的方式來計算。如果得到的return是有利的,則會增大生成這些動作的概率。這種方法的優(yōu)勢在于其是無偏的,因為其只依賴于實際的return,而非需要估計它。然而這種方法具有非常高的方差,因為同一個prompt產(chǎn)生的不同的軌跡會計算得到不同的return值,這是由于環(huán)境(一個episode中的隨機事件)和policy本身的隨機性決定的。
為了降低這里的高方差,一種常用的策略是使用優(yōu)勢函數(shù)(advantage function)估計來替代原始return,即。優(yōu)勢函數(shù)
表示在狀態(tài)
時采取動作
相比于同一狀態(tài)下所有動作的平均質(zhì)量來說是否會更好。這里需要引入兩個概念:
①動作的價值:在
時刻,給定當前狀態(tài)
,采取動作
,可以獲得的return的期望,具體的,
。
②狀態(tài)的價值:在
時刻,給定當前狀態(tài)
,可以獲得的return的期望,具體的,
,其中
是所有動作的集合。這里就是在上文
狀態(tài)下,求和所有下一個token出現(xiàn)概率與token對應(yīng)價值的乘積。
有了這兩個概念即可得到優(yōu)勢函數(shù)。優(yōu)勢函數(shù)的意義在于只有當前動作的return比平均水平更高時才能獲得正的優(yōu)勢值,從而被增強,如果低于平均水平就會被抑制。
使用優(yōu)勢函數(shù)的policy gradient方法是強化學(xué)習(xí)鄰域的重要支柱。優(yōu)勢函數(shù)的估計方法是多種多樣的,其中有一種廣泛采用的估計方法為廣義優(yōu)勢估計(Generalized Advantage Estimation, GAE),下一節(jié)將著重介紹。
- Generalized Advantage Estimation
優(yōu)勢函數(shù)定義為
函數(shù)與價值函數(shù)的差值。
考慮一個具體的動作,而價值函數(shù)是所有可能的動作的平均。而在實踐中,我們使用實際episode的return來估計
函數(shù),也就是蒙特卡洛采樣的return,這會引入非常高的方差,因為未來的reward包含大量的噪聲。一種減少這種噪聲的方法是使用價值函數(shù)來估計未來(時間步
之后)的return,也就是Temporal Difference (TD)的方法,這種方法通常有較大的偏差。本節(jié)要介紹的GAE算法是一種介于使用one-step TD return(高偏差)和完全蒙特卡洛return(高方差)之間的算法,可以平衡偏差和方差。接下來的內(nèi)容即是對GAE算法的推導(dǎo)。
首先,我們用來表示TD-
return,這是一種實際的reward和估計的return的集合:
這里的是折扣率。折扣率的意義可以這樣理解:每一步雖然都有一個即時的reward,但是每一步對后面的可能狀態(tài)都是有影響的,即后面的動作獲取的reward都能累計到前面的動作的貢獻。不過直接加上去可能不好,畢竟不是前面的動作直接獲取的reward,但是可以打個折扣再加上去,即乘個小于1的
。
使用TD- return的優(yōu)勢稱為
-step優(yōu)勢,定義為:
這里的,叫做TD error。在
-step優(yōu)勢中,如果
比較小,偏差會很高,因為優(yōu)勢估計只基于很少的步數(shù),所以非常依賴價值函數(shù)的準確性。而在
比較大時,方差會非常高,因為優(yōu)勢估計會把很多噪聲reward加進來。
為了平衡偏差和方差,GAE定義優(yōu)勢函數(shù)為-step優(yōu)勢的指數(shù)移動平均,權(quán)重為
:
GAE可以平滑地平衡高偏差()和高方差(
):
通過GAE,我們可以準確地得到優(yōu)勢函數(shù)的估計值
。這一估計值在進行policy gradient估計時將扮演重要的角色:
這里的是有限批量的樣本。后面我們將用
來表示
。
- Proximal Policy Optimization
PPO和TRPO是RL中的兩種關(guān)鍵技術(shù),旨在有效地訓(xùn)練policy而不損害其穩(wěn)定性。這些方法的遵循“小而穩(wěn)定的步驟”的思想,即輕微地進行policy的優(yōu)化,而不是強制進行可能破壞整個學(xué)習(xí)過程的激進更新。
在傳統(tǒng)的強化學(xué)習(xí)中,policy gradient的原則要求新舊policy在參數(shù)空間中保持接近。然而,參數(shù)空間中的這種接近并不一定等同于相似的性能,參數(shù)的輕微變化可能會極大地影響策略的有效性。這一部分原因要歸結(jié)于神經(jīng)網(wǎng)絡(luò)是過參數(shù)化的,類似微調(diào)語言模型的LoRA方法,無需微調(diào)所有的模型參數(shù)即可將模型適配到特定的下游任務(wù)上。此外,如果不加限制地大步更新,就可能導(dǎo)致policy表現(xiàn)崩潰,這種情況通常被描述為“掉下懸崖(falling off the cliff)”。這種固有的風險是原始的policy gradient中樣本效率的限制因素。
- TRPO
TRPO的方法沒有受到參數(shù)接近性(parameter closeness)的限制,而是對policy更新引入了一種不同的約束。其通過確保新舊policy model的KL散度在一個可接受的限制范圍內(nèi)來正則化policy的更新:
這里的是更新之前舊的policy參數(shù)。
- PPO-Penalty
PPO的方法有兩個主要的變種:PPO-Penalty和PPO-Clip。TRPO是一種KL散度的硬性限制,而PPO-Penalty通過采用基于懲罰的方法而不是約束來以無約束優(yōu)化問題的方式優(yōu)化policy:
這里的是懲罰因子。
- PPO-Clip
PPO-Clip試圖保持新policy接近舊policy,但不像TRPO那樣對KL散度施加約束,而是在其目標函數(shù)中根據(jù)新舊policy的預(yù)測概率的比值進行截斷。目標函數(shù)可以表示為:
這里的是新policy與舊policy的概率的比例,
是一個超參數(shù),用于控制新policy能夠偏離舊policy多遠。
函數(shù)將概率的比值限制在
之間。
函數(shù)充當正則化器,限制policy從一個迭代到下一個迭代的急劇變化的程度。防止過大的policy更新,確保了學(xué)習(xí)過程的魯棒性,同時保持了比普通policy gradient方法更高效的樣本學(xué)習(xí)。
- 價值函數(shù)估計
在PPO中,文章開篇提到的critic model通常用來作為價值函數(shù),來評估每個狀態(tài)的期望return。其學(xué)習(xí)目標為最小化其預(yù)測值與真實return之間的差異,目標函數(shù)通常采用MSE損失,即:
這里的代表critic model(參數(shù)為
)在
狀態(tài)的預(yù)測值,
代表實際狀態(tài)
的return值,通常估計為:
。
- 混合預(yù)訓(xùn)練梯度
為了緩解PPO訓(xùn)練后的模型在通用語言能力上的退化和災(zāi)難性遺忘問題,可以在強化學(xué)習(xí)訓(xùn)練過程中加入預(yù)訓(xùn)練數(shù)據(jù),這種方法通常稱為PPO-ptx,其損失函數(shù)為:
這里的是一個超參數(shù),
是預(yù)訓(xùn)練數(shù)據(jù)的分布。
四、總結(jié)
綜合前面的描述,整個PPO訓(xùn)練過程的算法可以表示為:

另外也可參考整個流程的框架圖:
