生成式大模型的RLHF技術(shù)(一):基礎(chǔ)

一、概述

大語言模型(LLMs)在預(yù)訓(xùn)練的過程中通常會捕捉數(shù)據(jù)的特征,而這些訓(xùn)練數(shù)據(jù)通常既包含高質(zhì)量的也包含低質(zhì)量的,因此模型有時會產(chǎn)生不被期望的行為,如編造事實,生成有偏見或有毒的文本,甚至對人類有害的內(nèi)容。因此,將LLMs與人類價值觀(如helpful, honest, 和harmless, 即3H)對齊是非常重要的,目前采用的主流的技術(shù)即是基于人類反饋的強化學(xué)習(xí)技術(shù)(RLHF)。

通常來說,RLHF包括三個步驟:
①supervised fine-tuning (SFT):對LLMs進行微調(diào),LLMs通過模仿人類標注的對話示例來學(xué)習(xí)通用的的類似人類的對話。
②reward model (RM) training:對于模型對同一個prompt的多個回復(fù),利用人類標注來進行排序以獲取人類偏好,然后單獨使用另一個語言模型作為reward model,在這個reward model上使用標注的數(shù)據(jù)進行訓(xùn)練(類似排序任務(wù))。
③proximal policy optimization (PPO):以訓(xùn)練得到的reward model作為reward function來繼續(xù)訓(xùn)練優(yōu)化LLMs,促進其與人類偏好的對齊。

RLHF

RLHF的整個過程如上圖所示,也可參考OpenAI的InstructGPT的做法:InstructGPT:語言模型的人類反饋指令對齊。InstructGPT使用的數(shù)據(jù)規(guī)模如下表所示:

數(shù)據(jù)規(guī)模

本文會介紹基于PPO(proximal policy optimization)的RLHF的技術(shù)細節(jié),包括Reward Modeling、Policy Gradient Methods、Advantage Actor-Critic (A2C)等。本文主要介紹A2C框架,其中包括四個模型的調(diào)度:
①Policy Model/Actor Model:由SFT之后的模型初始化而來。作為策略(policy)模型,用于接收上文,做出動作,預(yù)測下一個字符。學(xué)習(xí)完畢之后,我們最終使用的就是這個模型。
②Reference Model:和Actor Model同樣初始化自SFT Model,訓(xùn)練過程中凍結(jié)參數(shù),用于和Actor Model做對比,保證模型不要偏離原始SFT Model太多。
③Reward Model:作為環(huán)境(env),訓(xùn)練過程中凍結(jié)參數(shù),針對每一個狀態(tài),給出獎勵分數(shù)。
④Critic Model:由Reward Model初始化而來,用于近似價值函數(shù),輸入為當前的狀態(tài),估計當前狀態(tài)的價值。

二、Reward Modeling

Reward model可以使用移除了最后一個unembedding層的預(yù)訓(xùn)練語言模型來作為基礎(chǔ)架構(gòu),通常就是將最后一個token最終的embedding輸入給一個線性層,然后得到一個標量值,即是reward的值。在InstructGPT中,作者嘗試用了1.3B、6B和175B的GPT-3來做實驗,最終綜合考慮只用6B的模型來訓(xùn)練reward model。在訓(xùn)練reward model時,對于同一個輸入promptx,有一個更偏好的輸出y_w和一個相對不偏好的輸出y_l。每一對偏好和不偏好的回復(fù)的損失為:

\mathcal{L}(\psi )=log\, \sigma (r(x,y_{w})-r(x,y_{l}))

這里的\sigma是sigmoid函數(shù),r代表reward model,其參數(shù)為\psi,r(x,y)是一個reward model為xy預(yù)測的標量得分。另外,也可以額外加上一個模仿學(xué)習(xí)(imitation learning)的損失,也就是一個語言模型的預(yù)訓(xùn)練損失,來讓模型模仿句子對中更偏好的那一個:

\mathcal{L}(\psi )=-\lambda \mathbb{E}_{(x,y_w,y_l)\backsim \mathcal{D}_{rm}}[log\, \sigma (r(x,y_{w})-r(x,y_{l}))]+\beta _{rm}\mathbb{E}_{(x,y_w)\backsim \mathcal{D}_{rm}}[log(r^{\prime}(x,y_w))]

這里的\lambda\beta _{rm}都是超參數(shù),\mathcal{D}_{rm}是訓(xùn)練集的經(jīng)驗分布,r^{\prime}r除了頂部線性層不同以外是同一個模型(r^{\prime}線性層的維度為詞典的大小),r^{\prime}(x,y_w)是給定promptx和偏好回復(fù)y_w后的似然。

在PPO階段,使用訓(xùn)練得到的reward model來作為reward function訓(xùn)練policy模型時,還可以為reward function添加一個基于當前policy模型\pi _{\phi }^{RL}和SFT模型\pi ^{SFT}之間KL散度的懲罰項,此時reward function為:

r_{total}=r(x,y)-\eta KL(\pi _{\phi }^{RL}(y|x),\pi ^{SFT}(y|x))

這里的\eta是一個超參數(shù)。這個KL散度的懲罰項,通常來說有兩個作用:
①作為一個entropy bonus,促進模型在policy空間中的探索,防止policy model過早地收斂到一個單一的模式。
②可以確保強化學(xué)習(xí)policy model的輸出不會與reward model在訓(xùn)練階段遇到的樣本嚴重偏離。

三、Reinforcement Learning

將強化學(xué)習(xí)應(yīng)用于對話生成是一種艱難的挑戰(zhàn),這是因為其狀態(tài)-動作空間(state-action space)是非常巨大和廣闊的。在這樣的場景中,我們將人類的互動作為環(huán)境(environment)。在每個時間步t,agent(也就是LLMs)將從環(huán)境(即對話歷史)中接收一個狀態(tài)s_t,其中包括所有的對話歷史文本(也就是prompt和已經(jīng)生成的回復(fù))。接著,基于policy\pi,agent的動作a_t為生成下一個token。環(huán)境相應(yīng)的也會反饋一個rewardr(s_t,a_t),這個reward來自于一個reward functionrr是從人類偏好數(shù)據(jù)中訓(xùn)練得到的,也就是前文的reward model。此后,agent將轉(zhuǎn)換到下一個狀態(tài)s_{t+1}。整個過程將得到一個軌跡(trajectory)\tau =\left \{s_1,a_1,s_2,a_2,\dots ,s_T,a_T\right \}。對于LLMs的一個輸入x和輸出y來說,s_1=x,a_i=y_i,可采取的動作的所有選擇即是模型詞典中的所有token。強化學(xué)習(xí)的目的即是最大化一個軌跡的累積reward(也就是回報,return)。一種有限期無折扣回報( finite-horizon undiscounted return)為R(\tau )=\textstyle\sum_{t=1}^{T^{\prime}}r(s_t,a_t),即有限步數(shù)的累積reward的加和。另一種無限期折扣回報( infinite-horizon discounted return)為R(\tau )=\textstyle\sum_{t=0}^{\infty }\gamma ^{t}r(s_t,a_t),這種計算方式考慮了整個軌跡上所獲得的的所有回報,其中\gamma \in (0,1)為折扣率。

  1. Policy Gradient Methods

在介紹最常用的A2C方法之前,先介紹下更基礎(chǔ)一些的policy gradient方法。Policy gradient方法是一種強化學(xué)習(xí)技術(shù),其直接優(yōu)化agent的policy,也就是從state到action的映射。Policy gradient方法的核心思想是使用梯度上升方法直接優(yōu)化policy。從本質(zhì)上講,這類方法調(diào)整policy model的參數(shù),使其朝著最大限度地提高預(yù)期return的方向優(yōu)化。Policy\pi通常由\theta參數(shù)化,我們將其表示為\pi (a|s,\theta),即在狀態(tài)s下采取動作a的概率。Policy gradient的參數(shù)更新方式為:

\theta \gets \theta +\alpha \nabla _{\theta }J(\theta )

這里的\alpha是學(xué)習(xí)率,J(\theta)表示當采用policy\pi _{\theta }時的期望return,其梯度\nabla _{\theta }J(\theta )被稱為policy gradient。一個policy gradient的通用形式為:

\nabla _{\theta }J(\theta )=\mathbb{E}_{\tau \sim \pi _{\theta }}\left [\sum_{t=0}^{T}\nabla _{\theta }log\; \pi _{\theta }(a_{t},s_{t})\Phi _{t}\right ]

這里的\Phi _{t}可以是\Phi _{t}=R(\tau), \Phi _{t}=\textstyle\sum_{t^{\prime}=t}^{T}R(s_{t^{\prime}},a_{t^{\prime}}),\Phi _{t}=\textstyle\sum_{t^{\prime}=t}^{T}R(s_{t^{\prime}},a_{t^{\prime}})-b(s_{t})b是一個baseline)中的任意一個。這些不同的形式會得到policy gradient相同的期望值,但會得到不同的方差。

Return通過蒙特卡洛采樣的方式來計算。如果得到的return是有利的,則會增大生成這些動作的概率。這種方法的優(yōu)勢在于其是無偏的,因為其只依賴于實際的return,而非需要估計它。然而這種方法具有非常高的方差,因為同一個prompt產(chǎn)生的不同的軌跡會計算得到不同的return值,這是由于環(huán)境(一個episode中的隨機事件)和policy本身的隨機性決定的。

為了降低這里的高方差,一種常用的策略是使用優(yōu)勢函數(shù)(advantage function)估計來替代原始return,即\Phi _{t}=A(s_t,a_t)。優(yōu)勢函數(shù)A(s_t,a_t)表示在狀態(tài)s_t時采取動作a_t相比于同一狀態(tài)下所有動作的平均質(zhì)量來說是否會更好。這里需要引入兩個概念:
①動作的價值Q(s_t,a_t):在t時刻,給定當前狀態(tài)s_t,采取動作a_t,可以獲得的return的期望,具體的,Q(s_t,a_t)=\mathbb{E}\left [\textstyle\sum _{t^{\prime}=t}^{T}r(s_{t^{\prime}},a_{t^{\prime}})|s_t,a_t\right ]
②狀態(tài)的價值V(s_t):在t時刻,給定當前狀態(tài)s_t,可以獲得的return的期望,具體的,V(s_t)=\mathbb{E}_{a\in \mathcal{A}}\left [Q(s_t,a)\right ]=\textstyle\sum _{a}\pi _{\theta }(a|s_t)Q(s_t,a),其中\mathcal{A}是所有動作的集合。這里就是在上文s_t狀態(tài)下,求和所有下一個token出現(xiàn)概率與token對應(yīng)價值的乘積。

有了這兩個概念即可得到優(yōu)勢函數(shù)A(s_{t},a_{t})=Q(s_{t},a_{t})-V(s_{t})。優(yōu)勢函數(shù)的意義在于只有當前動作的return比平均水平更高時才能獲得正的優(yōu)勢值,從而被增強,如果低于平均水平就會被抑制。

使用優(yōu)勢函數(shù)的policy gradient方法是強化學(xué)習(xí)鄰域的重要支柱。優(yōu)勢函數(shù)的估計方法是多種多樣的,其中有一種廣泛采用的估計方法為廣義優(yōu)勢估計(Generalized Advantage Estimation, GAE),下一節(jié)將著重介紹。

  1. Generalized Advantage Estimation

優(yōu)勢函數(shù)A定義為Q函數(shù)與價值函數(shù)的差值。Q考慮一個具體的動作,而價值函數(shù)是所有可能的動作的平均。而在實踐中,我們使用實際episode的return來估計Q函數(shù),也就是蒙特卡洛采樣的return,這會引入非常高的方差,因為未來的reward包含大量的噪聲。一種減少這種噪聲的方法是使用價值函數(shù)來估計未來(時間步t之后)的return,也就是Temporal Difference (TD)的方法,這種方法通常有較大的偏差。本節(jié)要介紹的GAE算法是一種介于使用one-step TD return(高偏差)和完全蒙特卡洛return(高方差)之間的算法,可以平衡偏差和方差。接下來的內(nèi)容即是對GAE算法的推導(dǎo)。

首先,我們用\hat{R}_{t}^{k}來表示TD-k return,這是一種實際的reward和估計的return的集合:

\hat{R}_{t}^{k}=r_{t}+\gamma r_{t+1}+\cdots +\gamma ^{(k-1)}r_{t+k-1}+\gamma ^kV(s_{t+k})

這里的\gamma是折扣率。折扣率的意義可以這樣理解:每一步雖然都有一個即時的reward,但是每一步對后面的可能狀態(tài)都是有影響的,即后面的動作獲取的reward都能累計到前面的動作的貢獻。不過直接加上去可能不好,畢竟不是前面的動作直接獲取的reward,但是可以打個折扣再加上去,即乘個小于1的\gamma

使用TD-k return的優(yōu)勢稱為k-step優(yōu)勢,定義為:

\hat{A}_t^k=\hat{R}_t^k-V(s_t)=-V(s_t)+r_t+\gamma r_{t+1}+\cdots +\gamma ^{(k-1)}r_{t+k-1}+\gamma ^kV(s_{t+k})=\sum_{l=1}^{k}\gamma ^{l}\delta _{t+l}

這里的\delta _{t}=r_t+\gamma V(s_{t+1})-V(s_t),叫做TD error。在k-step優(yōu)勢中,如果k比較小,偏差會很高,因為優(yōu)勢估計只基于很少的步數(shù),所以非常依賴價值函數(shù)的準確性。而在k比較大時,方差會非常高,因為優(yōu)勢估計會把很多噪聲reward加進來。

為了平衡偏差和方差,GAE定義優(yōu)勢函數(shù)為k-step優(yōu)勢的指數(shù)移動平均,權(quán)重為(1-\lambda )\lambda ^{(k-1)}

\begin{align} \hat{A}_{t}^{GAE(\gamma ,\lambda )}&=(1-\lambda )(\hat{A}_{t}^{(1)}+\lambda \hat{A}_{t}^{(2)}+\lambda ^{2}\hat{A}_{t}^{(3)}+\cdots )\\ &=(1-\lambda )(\delta _{t}+\lambda (\delta _{t}+\gamma \delta _{t+1})+\lambda ^2(\delta _{t}+\gamma \delta _{t+1}+\gamma ^2\delta _{t+2})+\cdots )\\ &=(1-\lambda )(\delta _{t}(1+\lambda +\lambda ^2+\cdots )+\gamma \delta _{t+1}(\lambda +\lambda ^2+\lambda ^3+\cdots )+\gamma ^2\delta _{t+2}(\lambda ^2+\lambda ^3+\lambda ^4+\cdots )+\cdots )\\ &=(1-\lambda )(\delta _{t}(\frac{1}{1-\lambda })+\gamma \delta _{t+1}(\frac{\lambda }{1-\lambda })+\gamma ^2\delta _{t+2}(\frac{\lambda ^2}{1-\lambda })+\cdots )\\ &=\sum_{l=0}^{\infty }(\gamma \lambda )^l\delta _{t+l} \end{align}

GAE可以平滑地平衡高偏差(\lambda =0)和高方差(\lambda =1):

GAE(\gamma ,0):\hat{A}_t=\delta _t=r_t+\gamma V(s_{t+1})-V(s_t)\\ GAE(\gamma ,1):\hat{A}_t=\sum_{l=0}^{\infty }\gamma ^l\delta _{t+l}=\sum_{l=0}^{\infty }\gamma ^lr _{t+l}-V(s_t)

通過GAE,我們可以準確地得到優(yōu)勢函數(shù)A(s_t,a_t)的估計值\hat{A}_t。這一估計值在進行policy gradient估計時將扮演重要的角色:

\nabla _{\theta }\hat{J}(\theta )=\frac{1}{|\mathcal{D}|}\sum_{\tau \in \mathcal{D}}\displaystyle\sum_{t=1}^{T}\nabla _{\theta }log\; \pi _{\theta }(a_t|s_t)\hat{A}_t

這里的\mathcal{D}是有限批量的樣本。后面我們將用\hat{\mathbb{E}}_t來表示\frac{1}{|\mathcal{D}|}\textstyle\sum_{\tau \in \mathcal{D}}\textstyle\sum_{t=1}^{T}

  1. Proximal Policy Optimization

PPO和TRPO是RL中的兩種關(guān)鍵技術(shù),旨在有效地訓(xùn)練policy而不損害其穩(wěn)定性。這些方法的遵循“小而穩(wěn)定的步驟”的思想,即輕微地進行policy的優(yōu)化,而不是強制進行可能破壞整個學(xué)習(xí)過程的激進更新。

在傳統(tǒng)的強化學(xué)習(xí)中,policy gradient的原則要求新舊policy在參數(shù)空間中保持接近。然而,參數(shù)空間中的這種接近并不一定等同于相似的性能,參數(shù)的輕微變化可能會極大地影響策略的有效性。這一部分原因要歸結(jié)于神經(jīng)網(wǎng)絡(luò)是過參數(shù)化的,類似微調(diào)語言模型的LoRA方法,無需微調(diào)所有的模型參數(shù)即可將模型適配到特定的下游任務(wù)上。此外,如果不加限制地大步更新,就可能導(dǎo)致policy表現(xiàn)崩潰,這種情況通常被描述為“掉下懸崖(falling off the cliff)”。這種固有的風險是原始的policy gradient中樣本效率的限制因素。

  • TRPO

TRPO的方法沒有受到參數(shù)接近性(parameter closeness)的限制,而是對policy更新引入了一種不同的約束。其通過確保新舊policy model的KL散度在一個可接受的限制范圍內(nèi)來正則化policy的更新:

\max_{\theta } \hat{\mathbb{E}}_t\left [\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t\right ],\\ \mathrm{subject\; to}\; \hat{\mathbb{E}}_t[KL(\pi _{\theta _{old}}(\cdot |s_t),\pi _{\theta }(\cdot |s_t))]\le \delta

這里的\theta _{old}是更新之前舊的policy參數(shù)。

  • PPO-Penalty

PPO的方法有兩個主要的變種:PPO-Penalty和PPO-Clip。TRPO是一種KL散度的硬性限制,而PPO-Penalty通過采用基于懲罰的方法而不是約束來以無約束優(yōu)化問題的方式優(yōu)化policy:

\mathcal{L}_{\mathrm{ppo-penalty}}(\theta )=\hat{\mathbb{E}}_t\left [\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t\right ]-\beta KL(\pi _{\theta _{old}}(\cdot |s_t),\pi _{\theta }(\cdot |s_t))

這里的\beta是懲罰因子。

  • PPO-Clip

PPO-Clip試圖保持新policy接近舊policy,但不像TRPO那樣對KL散度施加約束,而是在其目標函數(shù)中根據(jù)新舊policy的預(yù)測概率的比值進行截斷。目標函數(shù)可以表示為:

\mathcal{L}_{\mathrm{ppo-clip}}(\theta )=\hat{\mathbb{E}}_t\left [\mathrm{min}\left (\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}\hat{A}_t,\mathrm{clip}\left (\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)},1-\epsilon ,1+\epsilon \right )\hat{A}_t\right )\right ]

這里的\frac{\pi _{\theta }(a_t|s_t)}{\pi _{\theta _{old}}(a_t|s_t)}是新policy與舊policy的概率的比例,\epsilon是一個超參數(shù),用于控制新policy能夠偏離舊policy多遠。\mathrm{clip}函數(shù)將概率的比值限制在(1-\epsilon ,1+\epsilon )之間。\mathrm{clip}函數(shù)充當正則化器,限制policy從一個迭代到下一個迭代的急劇變化的程度。防止過大的policy更新,確保了學(xué)習(xí)過程的魯棒性,同時保持了比普通policy gradient方法更高效的樣本學(xué)習(xí)。

  • 價值函數(shù)估計

在PPO中,文章開篇提到的critic model通常用來作為價值函數(shù),來評估每個狀態(tài)的期望return。其學(xué)習(xí)目標為最小化其預(yù)測值與真實return之間的差異,目標函數(shù)通常采用MSE損失,即:

\mathcal{L}_{\mathrm{critic}}(\phi )=\hat{\mathbb{E}}_t\left [\left \|V_\phi (s_t)-\hat{R}_t\right \|^2\right ]

這里的V_\phi (s_t)代表critic model(參數(shù)為\phi)在s_t狀態(tài)的預(yù)測值,\hat{R}_t代表實際狀態(tài)s_t的return值,通常估計為:\hat{R}_t=\textstyle\sum_{l=0}^{\infty }\gamma ^{l}r_{t+l}

  • 混合預(yù)訓(xùn)練梯度

為了緩解PPO訓(xùn)練后的模型在通用語言能力上的退化和災(zāi)難性遺忘問題,可以在強化學(xué)習(xí)訓(xùn)練過程中加入預(yù)訓(xùn)練數(shù)據(jù),這種方法通常稱為PPO-ptx,其損失函數(shù)為:

\mathcal{L}_{\mathrm{ppo-ptx}}(\theta )=\mathcal{L}_{\mathrm{ppo-clip}}(\theta )+\lambda _{ptx}\mathbb{E}_{x\backsim \mathcal{D}_{pretrain}}[log(\pi _{\theta }^{RL}(x))]

這里的\lambda _{ptx}是一個超參數(shù),\mathcal{D}_{pretrain}是預(yù)訓(xùn)練數(shù)據(jù)的分布。

四、總結(jié)

綜合前面的描述,整個PPO訓(xùn)練過程的算法可以表示為:

算法

另外也可參考整個流程的框架圖:

框架

參考資料

Secrets of RLHF in Large Language Models Part I: PPO

大模型RLHF理論詳細講解

拆解大語言模型RLHF中的PPO

RLHF實踐

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 0. 引言 最近跟著 OpenAI 的 Spinning Up 教學(xué)文檔 學(xué)習(xí)了一遍 Deep RL,對這個領(lǐng)域有...
    OurNote閱讀 6,758評論 0 8
  • 簡介 2022年11月,OpenAI推出了一款A(yù)I聊天機器人程序,其強大的問答能力瞬間引爆全網(wǎng)關(guān)注度。 組成部分:...
    臻甄閱讀 1,998評論 0 0
  • 大家好,上期我們講到研發(fā)人員正在研究解決語言模型中的一致性問題。ChatGPT 使用了人類反饋來指導(dǎo)學(xué)習(xí)過程,對其...
    城北楠哥閱讀 671評論 0 0
  • 在公司看文檔,對用到的一些知識做簡單梳理;大部分idea來源于DeepMind或OpenAI PPO的目標函數(shù) P...
    YukiRain閱讀 789評論 0 0
  • 1. 背景 最近本qiang~老看到一些關(guān)于大語言模型的DPO、RLHF算法,但都有些云里霧里,因此靜下心來收集資...
    mengrennwpu閱讀 632評論 0 0

友情鏈接更多精彩內(nèi)容