在機器學(xué)習(xí)領(lǐng)域內(nèi),經(jīng)常會聽到回歸問題、線性回歸等專業(yè)術(shù)語,那什么是回歸問題,線性回歸又是什么意思以及怎么訓(xùn)練一個模型、怎么調(diào)整評估模型的預(yù)測性能呢,本文將結(jié)合網(wǎng)上資料及TensorFlow來訓(xùn)練一個最簡單的線性模型并評估其性能優(yōu)劣。
01
線性回歸
什么是回歸問題
回歸分析是確定多個相互依賴的變量之間的定量關(guān)系的方法?;貧w分析通常用于預(yù)測,發(fā)現(xiàn)變量之間的因果關(guān)系。
線性回歸
線性回歸是回歸分析的一種,它研究的是多個變量之間的線性組合,線性回歸假設(shè)的是預(yù)測的目標和特征之間是線性相關(guān),通??梢允褂靡韵鹿絹肀磉_:
其中w表示的是特征x的權(quán)重,b表示的是偏差。在幾何中,w又表示直線的斜率,w表示截距。
線性回歸的目的在于,通過大量的訓(xùn)練數(shù)據(jù)來找到最合適的w,b的值。機器學(xué)習(xí)中常見的專業(yè)術(shù)語“模型”,就是上述類似的一組描述特征和目標之間關(guān)系的公式。
機器學(xué)習(xí)訓(xùn)練出來的模型,最終需要應(yīng)用的到驗證數(shù)據(jù)上進行預(yù)測,而模型的優(yōu)劣好壞如何度量,如何才能找到合適的w,b的值呢?這就需要定量分析,需要引入“損失”、“損失函數(shù)”等概念了。
02
損失函數(shù)
我們先看下面的一個例子,紅色的直線表示當(dāng)前的學(xué)習(xí)模型,藍色的點表示測試數(shù)據(jù),通過下圖我們可以看到只有一個點跟模型擬合,其他的都欠擬合。從直覺上我們知道,當(dāng)前的模型可能沒那么好,那么如何定量的去衡量這個模型的好壞呢?
我們?nèi)匀灰陨蠄D為例,取x = 12這個樣本值,那么相應(yīng)藍色點的y值即為真實值,而相應(yīng)紅線所在的prediction(x)即表示預(yù)測值。那么y - prediction(x)即為損失。那么衡量這種損失的函數(shù),我們稱之為損失函數(shù)。
目前有這么幾種損失函數(shù):
平方損失,又稱L2損失
殘差平方和(RSS, Residual of Sum of Square)
均方誤差(MSE, Mean Square Error)
均方根誤差(RMSE, Root Mean Square Error)
其實都很好理解,我們一個一個來看。
平方損失
平方損失最好理解,單個樣本的平方損失,可以有如下公式給出:
其實即為真實值與預(yù)測值之間差值的平方。
******殘差平方和******
殘差平方和指的是將單個樣本的真實值與預(yù)測值之差的平方相加求和:
也即所有樣本的平方損失之和。
****均方誤差****
均方誤差指的是平均平方損失,也即需要計算所有樣本的平方損失之和,然后再除以樣本之和??梢杂扇缦鹿浇o出:
其中(x,y)表示樣本,x指的是樣本的特征值,y指的是樣本的標簽(或者說目標值)。
D表示數(shù)據(jù)集,prediction(x)指的是模型預(yù)測值,N表示樣本個數(shù)。
******均方根誤差******
均方根誤差與均方誤差,從字面意思上來講其實有很大的聯(lián)系,均方根誤差即為均方誤差的開平方根。
在實際編程應(yīng)用中,經(jīng)常會先求出MSE,然后對其開平方即可得到RMSE。
了解完所有的損失函數(shù)之后,我們不禁有個疑問,這些誤差到底在機器學(xué)習(xí)中起到了什么作用?
文章的開頭我們指出,我們需要利用損失函數(shù)來評估模型的優(yōu)劣,或者說我們通過損失函數(shù)的取值來調(diào)整模型的參數(shù),以期獲得最合適的模型。既然損失函數(shù)描述的是預(yù)測值與真實值的偏離情況,那么在實際訓(xùn)練過程中,我們應(yīng)該找到使損失函數(shù)取最小值時取對應(yīng)模型的參數(shù)。
我們假設(shè)目標和特征之間存在線性相關(guān)性,于是就有公式 y' = wx + b (假設(shè)只有一個特征)。以MSE(均方誤差)及線性回歸為例,我們將公式帶入MSE中替換預(yù)測值,即可得如下公式:
該公式是以w,b為參數(shù)的函數(shù),我們的目標是求取w,b的值使得該函數(shù)取最小值。由于N是一個常數(shù),我們可以簡化成下面的公式:
機器學(xué)習(xí)的模型訓(xùn)練問題到此就被我們簡化成了求解上面公式(2)的最小值,求解函數(shù)最小值的方法有很多,比較常見的有:
最小二乘法
梯度下降法
最小二乘法我們這里省略不講解,本文只講解梯度下降法。
03
迭代法與梯度下降
講解梯度下降法之前,我們先了解一下機器學(xué)習(xí)訓(xùn)練過程中的迭代法。
在機器學(xué)習(xí)過程中,我們的會選取一個初始化參數(shù)(w,b),然后通過損失函數(shù)來不斷地調(diào)整參數(shù)值(w,b)直到訓(xùn)練處一個符合預(yù)期的模型。下圖是一個簡化版本的機器學(xué)習(xí)迭代過程:
圖例來自Google機器學(xué)習(xí)文檔
上圖中,“計算參數(shù)更新”就是梯度下降法的本質(zhì)。在回歸問題中,權(quán)重值w的取值和損失之間的關(guān)系如下圖,始終是一個凸函數(shù)(凸函數(shù),用幾何的方式來理解就是圖形內(nèi)任意兩點構(gòu)成的線段仍然被圖形包圍)。也就是說找到了一個極小值點,那么這個極小值點就是全局最小值點。
圖例來自Google機器學(xué)習(xí)文檔
上圖中,任意選取一個起點。由于我們能夠看到圖形全貌,我們很快就能找到損失下降的方向,并最終定位到極小值點。但是在實際應(yīng)用中,我們?nèi)我膺x取的起點后,是不能立刻確定哪個方向是損失下降的方向。舉個例子,常識告訴我們地球是圓的,但站著地球上,我們始終感受不到弧度,也就無從找到某段圓弧的最低點。同樣,我們就需要通過某種方式判斷哪個方向是損失下降的方向。
在熟知的高數(shù)中,我們知道導(dǎo)數(shù)可以理解成斜率,或者曲線某個點的導(dǎo)數(shù)可以理解成改點切線的斜率。導(dǎo)數(shù)為正,則曲線遞增;反之,曲線遞減。梯度就可以理解成導(dǎo)數(shù)(偏導(dǎo)數(shù))。
梯度下降就是按照負梯度的方向,逐漸的去找到最小值點。
04
模型訓(xùn)練
了解了上面的線性回歸的基本原理以及模型訓(xùn)練的目的之后,我們總結(jié)一下線性回歸模型訓(xùn)練的套路:
分析訓(xùn)練數(shù)據(jù),找到數(shù)據(jù)的合適的分布
確定問題的類型,回歸還是分類問題
加載數(shù)據(jù)集,輸入訓(xùn)練數(shù)據(jù)
開始訓(xùn)練模型,迭代過程
使用測試數(shù)據(jù)來分析當(dāng)前模型的優(yōu)劣
其他(后面再講)
可以參照google機器學(xué)習(xí)文檔來詳細了解如何使用tensorflow訓(xùn)練模型。接下來會寫一篇TensorFlow模型訓(xùn)練初體驗。
** 結(jié)語 **
線性回歸是機器學(xué)習(xí)中常見的一種類型,通過損失函數(shù)來反復(fù)迭代機器學(xué)習(xí)模型。在迭代過程中采用了梯度下降法,尋找最優(yōu)解,整個流程比較容易理解。
關(guān)注公眾賬號<歲與禾>,獲取最新信息