如何快速理解線性回歸及模型訓(xùn)練?

在機器學(xué)習(xí)領(lǐng)域內(nèi),經(jīng)常會聽到回歸問題、線性回歸等專業(yè)術(shù)語,那什么是回歸問題,線性回歸又是什么意思以及怎么訓(xùn)練一個模型、怎么調(diào)整評估模型的預(yù)測性能呢,本文將結(jié)合網(wǎng)上資料及TensorFlow來訓(xùn)練一個最簡單的線性模型并評估其性能優(yōu)劣。

01

線性回歸

什么是回歸問題

回歸分析是確定多個相互依賴的變量之間的定量關(guān)系的方法?;貧w分析通常用于預(yù)測,發(fā)現(xiàn)變量之間的因果關(guān)系。

線性回歸

線性回歸是回歸分析的一種,它研究的是多個變量之間的線性組合,線性回歸假設(shè)的是預(yù)測的目標和特征之間是線性相關(guān),通??梢允褂靡韵鹿絹肀磉_:

image

其中w表示的是特征x的權(quán)重,b表示的是偏差。在幾何中,w又表示直線的斜率,w表示截距。

線性回歸的目的在于,通過大量的訓(xùn)練數(shù)據(jù)來找到最合適的w,b的值。機器學(xué)習(xí)中常見的專業(yè)術(shù)語“模型”,就是上述類似的一組描述特征和目標之間關(guān)系的公式。

機器學(xué)習(xí)訓(xùn)練出來的模型,最終需要應(yīng)用的到驗證數(shù)據(jù)上進行預(yù)測,而模型的優(yōu)劣好壞如何度量,如何才能找到合適的w,b的值呢?這就需要定量分析,需要引入“損失”、“損失函數(shù)”等概念了。

02

損失函數(shù)

我們先看下面的一個例子,紅色的直線表示當(dāng)前的學(xué)習(xí)模型,藍色的點表示測試數(shù)據(jù),通過下圖我們可以看到只有一個點跟模型擬合,其他的都欠擬合。從直覺上我們知道,當(dāng)前的模型可能沒那么好,那么如何定量的去衡量這個模型的好壞呢?

image

我們?nèi)匀灰陨蠄D為例,取x = 12這個樣本值,那么相應(yīng)藍色點的y值即為真實值,而相應(yīng)紅線所在的prediction(x)即表示預(yù)測值。那么y - prediction(x)即為損失。那么衡量這種損失的函數(shù),我們稱之為損失函數(shù)。

目前有這么幾種損失函數(shù):

  • 平方損失,又稱L2損失

  • 殘差平方和(RSS, Residual of Sum of Square)

  • 均方誤差(MSE, Mean Square Error)

  • 均方根誤差(RMSE, Root Mean Square Error)

其實都很好理解,我們一個一個來看。

平方損失

平方損失最好理解,單個樣本的平方損失,可以有如下公式給出:

image

其實即為真實值與預(yù)測值之間差值的平方。

******殘差平方和******

殘差平方和指的是將單個樣本的真實值與預(yù)測值之差的平方相加求和:

image

也即所有樣本的平方損失之和。

****均方誤差****

均方誤差指的是平均平方損失,也即需要計算所有樣本的平方損失之和,然后再除以樣本之和??梢杂扇缦鹿浇o出:

image

其中(x,y)表示樣本,x指的是樣本的特征值,y指的是樣本的標簽(或者說目標值)。

D表示數(shù)據(jù)集,prediction(x)指的是模型預(yù)測值,N表示樣本個數(shù)。

******均方根誤差******

均方根誤差與均方誤差,從字面意思上來講其實有很大的聯(lián)系,均方根誤差即為均方誤差的開平方根。

image

在實際編程應(yīng)用中,經(jīng)常會先求出MSE,然后對其開平方即可得到RMSE。

了解完所有的損失函數(shù)之后,我們不禁有個疑問,這些誤差到底在機器學(xué)習(xí)中起到了什么作用?

文章的開頭我們指出,我們需要利用損失函數(shù)來評估模型的優(yōu)劣,或者說我們通過損失函數(shù)的取值來調(diào)整模型的參數(shù),以期獲得最合適的模型。既然損失函數(shù)描述的是預(yù)測值與真實值的偏離情況,那么在實際訓(xùn)練過程中,我們應(yīng)該找到使損失函數(shù)取最小值時取對應(yīng)模型的參數(shù)。

我們假設(shè)目標和特征之間存在線性相關(guān)性,于是就有公式 y' = wx + b (假設(shè)只有一個特征)。以MSE(均方誤差)及線性回歸為例,我們將公式帶入MSE中替換預(yù)測值,即可得如下公式:

image

該公式是以w,b為參數(shù)的函數(shù),我們的目標是求取w,b的值使得該函數(shù)取最小值。由于N是一個常數(shù),我們可以簡化成下面的公式:

image

機器學(xué)習(xí)的模型訓(xùn)練問題到此就被我們簡化成了求解上面公式(2)的最小值,求解函數(shù)最小值的方法有很多,比較常見的有:

  • 最小二乘法

  • 梯度下降法

最小二乘法我們這里省略不講解,本文只講解梯度下降法。

03

迭代法與梯度下降

講解梯度下降法之前,我們先了解一下機器學(xué)習(xí)訓(xùn)練過程中的迭代法。

在機器學(xué)習(xí)過程中,我們的會選取一個初始化參數(shù)(w,b),然后通過損失函數(shù)來不斷地調(diào)整參數(shù)值(w,b)直到訓(xùn)練處一個符合預(yù)期的模型。下圖是一個簡化版本的機器學(xué)習(xí)迭代過程:

image

圖例來自Google機器學(xué)習(xí)文檔

上圖中,“計算參數(shù)更新”就是梯度下降法的本質(zhì)。在回歸問題中,權(quán)重值w的取值和損失之間的關(guān)系如下圖,始終是一個凸函數(shù)(凸函數(shù),用幾何的方式來理解就是圖形內(nèi)任意兩點構(gòu)成的線段仍然被圖形包圍)。也就是說找到了一個極小值點,那么這個極小值點就是全局最小值點。

image

圖例來自Google機器學(xué)習(xí)文檔

上圖中,任意選取一個起點。由于我們能夠看到圖形全貌,我們很快就能找到損失下降的方向,并最終定位到極小值點。但是在實際應(yīng)用中,我們?nèi)我膺x取的起點后,是不能立刻確定哪個方向是損失下降的方向。舉個例子,常識告訴我們地球是圓的,但站著地球上,我們始終感受不到弧度,也就無從找到某段圓弧的最低點。同樣,我們就需要通過某種方式判斷哪個方向是損失下降的方向。

在熟知的高數(shù)中,我們知道導(dǎo)數(shù)可以理解成斜率,或者曲線某個點的導(dǎo)數(shù)可以理解成改點切線的斜率。導(dǎo)數(shù)為正,則曲線遞增;反之,曲線遞減。梯度就可以理解成導(dǎo)數(shù)(偏導(dǎo)數(shù))。

梯度下降就是按照負梯度的方向,逐漸的去找到最小值點。

04

模型訓(xùn)練

了解了上面的線性回歸的基本原理以及模型訓(xùn)練的目的之后,我們總結(jié)一下線性回歸模型訓(xùn)練的套路:

  • 分析訓(xùn)練數(shù)據(jù),找到數(shù)據(jù)的合適的分布

  • 確定問題的類型,回歸還是分類問題

  • 加載數(shù)據(jù)集,輸入訓(xùn)練數(shù)據(jù)

  • 開始訓(xùn)練模型,迭代過程

  • 使用測試數(shù)據(jù)來分析當(dāng)前模型的優(yōu)劣

  • 其他(后面再講)

可以參照google機器學(xué)習(xí)文檔來詳細了解如何使用tensorflow訓(xùn)練模型。接下來會寫一篇TensorFlow模型訓(xùn)練初體驗。

** 結(jié)語 **

線性回歸是機器學(xué)習(xí)中常見的一種類型,通過損失函數(shù)來反復(fù)迭代機器學(xué)習(xí)模型。在迭代過程中采用了梯度下降法,尋找最優(yōu)解,整個流程比較容易理解。

關(guān)注公眾賬號<歲與禾>,獲取最新信息

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容