樹回歸

上一篇文章中,我們比較全面地學(xué)習(xí)了線性回歸的原理是實(shí)現(xiàn),今天我們還是留在回歸板塊,針對(duì)樹回歸進(jìn)行學(xué)習(xí)和實(shí)踐。

01 樹回歸原理

相比于線性回歸,樹回歸更適合對(duì)復(fù)雜、非線性的數(shù)據(jù)進(jìn)行回歸建模。

原理

回想一下決策樹,樹回歸的原理就是決策樹(人家都叫”樹“回歸了……),在決策樹的學(xué)習(xí)中,有三種算法,ID3, C4.5, CART,前兩種算法只能處理離散型數(shù)據(jù),因此只能用于回歸,而CART算法由于采用二分法構(gòu)建樹,可以處理連續(xù)性數(shù)據(jù),因此也可以用于回歸,樹回歸的基本原理,就是CART算法。

混亂度衡量

說完算法原理,我們?cè)賮碚f說對(duì)于連續(xù)性數(shù)據(jù)的混亂度度量。我們知道,對(duì)于離散型數(shù)據(jù),可以使用信息增益、信息增益比、基尼指數(shù)這些指標(biāo)來衡量數(shù)據(jù)的混亂程度,那么對(duì)于連續(xù)型數(shù)據(jù),怎么衡量呢?

可以使用總方差來衡量連續(xù)性數(shù)據(jù)的混亂程度:(各數(shù)據(jù)-數(shù)據(jù)均值)**2,即方差*樣本數(shù),稱為總方差。

兩種樹回歸

樹回歸有兩種方式:回歸樹和模型樹,其中,

  • 回歸樹:葉節(jié)點(diǎn)是一個(gè)值:當(dāng)前葉子所有樣本標(biāo)簽均值
  • 模型樹:葉節(jié)點(diǎn)是一個(gè)線性回歸模型:當(dāng)前葉子所有樣本的線性回歸模型

下面我們分別實(shí)現(xiàn)回歸樹和模型樹。


02 回歸樹實(shí)現(xiàn)

  • 葉節(jié)點(diǎn)是一個(gè)值:當(dāng)前葉子所有樣本標(biāo)簽均值
  • 誤差衡量:總方差,表示一組數(shù)據(jù)的混亂度,是本組所有數(shù)據(jù)與這組數(shù)據(jù)均值之差的平方和

回歸樹的構(gòu)建邏輯:二分法,每次選擇一個(gè)最佳特征,并找到最佳切分特征值(使數(shù)據(jù)混亂度減少最多的[特征,特征值])進(jìn)行切分,得到左右子樹,然后對(duì)左右子樹遞歸調(diào)用createTree方法,直到?jīng)]有最佳特征為止。(實(shí)踐中,在選擇最佳特征時(shí),進(jìn)行了預(yù)剪枝)

#數(shù)據(jù)讀取
def loadDataSet(filename):
    dataMat=[]
    fr=open(filename,'r')
    for line in fr.readlines():
        curLine=line.strip().split('\t')
        fltLine=list(map(float,curLine)) #將curLine各元素轉(zhuǎn)換為float類型
        dataMat.append(fltLine)
    return mat(dataMat)

#二分?jǐn)?shù)據(jù)
def binSplitDataSet(dataset,feat,val):
    mat0=dataset[nonzero(dataset[:,feat]>val)[0],:] #數(shù)組過濾選擇特征大于指定值的數(shù)據(jù)
    mat1=dataset[nonzero(dataset[:,feat]<=val)[0],:] #數(shù)組過濾選擇特征小于指定值的數(shù)據(jù)
    return mat0,mat1

#定義回歸樹的葉子(該葉子上各樣本標(biāo)簽的均值)
def regLeaf(dataset):
    return mean(dataset[:,-1])

#定義連續(xù)數(shù)據(jù)的混亂度(總方差,即連續(xù)數(shù)據(jù)的混亂度=(該組各數(shù)據(jù)-該組數(shù)據(jù)均值)**2,即方差*樣本數(shù)))
def regErr(dataset):
    return var(dataset[:-1])*shape(dataset)[0]

"""最佳特征以及最佳特征值選擇函數(shù)"""
#leafType為葉節(jié)點(diǎn)取值,默認(rèn)為regleaf,即取樣本標(biāo)簽均值,對(duì)于模型樹,葉節(jié)點(diǎn)是一個(gè)線性模型
#errType為數(shù)據(jù)誤差(混亂度)計(jì)算方式,默認(rèn)為regErr,總方差
#ops[0]為以最佳特征及特征值切分?jǐn)?shù)據(jù)前后,數(shù)據(jù)混亂度的變化閾值,若小于該閾值,不切分
#ops[1]為切分后兩塊數(shù)據(jù)的最少樣本數(shù),若少于該值,不切分
#可以預(yù)想,回歸樹形狀對(duì)ops[0],ops[1]很敏感,若這兩個(gè)值過小,回歸樹會(huì)很臃腫,過擬合
def chooseBestSplit(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    tolS=ops[0];tolN=ops[1];m,n=shape(dataset)
    S=errType(dataset);bestS=inf;beatIndex=0;bestVal=0
    if len(set(dataset[:,-1].T.tolist()[0]))==1: #若只有一個(gè)類別
        return None,leafType(dataset)
    for featIndex in range(n-1):
        for splitVal in set(dataset[:,featIndex].T.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataset,featIndex,splitVal)
            #若切分后兩塊數(shù)據(jù)的最少樣本數(shù)少于設(shè)定值,不切分
            if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): 
                continue
            newS=errType(mat0)+errType(mat1)
            if newS<bestS:
                bestIndex=featIndex;bestVal=splitVal;bestS=newS
    #若以最佳特征及特征值切分后的數(shù)據(jù)混亂度與原數(shù)據(jù)混亂度差值小于閾值,不切分
    if (S-bestS)<tolS:
        return None,leafType(dataset)
    mat0,mat1=binSplitDataSet(dataset,bestIndex,bestVal)
    #若以最佳特征及特征值切分后兩塊數(shù)據(jù)的最少樣本數(shù)少于設(shè)定值,不切分
    if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
        return None,leafType(dataset)
    return bestIndex,bestVal

"""構(gòu)建回歸樹"""
def createTree(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    feat,val=chooseBestSplit(dataset,leafType,errType,ops)
    if feat==None:
        return val
    regTree={}
    regTree['spFeat']=feat
    regTree['spVal']=val
    lSet,rSet=binSplitDataSet(dataset,feat,val)
    regTree['left']=createTree(lSet,leafType,errType,ops)
    regTree['right']=createTree(rSet,leafType,errType,ops)
    return regTree

好了好了,寫了這么多代碼,我們來測(cè)試一下,原始數(shù)據(jù)分布如下圖,訓(xùn)練結(jié)果如下圖。可以看到,回歸樹模型將這組數(shù)據(jù)分到了5個(gè)葉節(jié)點(diǎn)上,目前看起來還過得去。


03 回歸樹剪枝

當(dāng)我們?cè)O(shè)置的最小分離葉節(jié)點(diǎn)樣本數(shù)、最小混亂度減小值等參數(shù)過小,可能產(chǎn)生過擬合,直觀的現(xiàn)象就是,訓(xùn)練出來非常多的葉子,其實(shí)是沒有必要的,此時(shí)就需要剪枝了(很形象嘛)。

剪枝分為預(yù)剪枝和后剪枝,

  • 預(yù)剪枝:在chooseBestSplit函數(shù)中的幾個(gè)提前終止條件(切分樣本小于閾值、混亂度減弱小于閾值),都是預(yù)剪枝(參數(shù)敏感)。
  • 后剪枝:使用測(cè)試集對(duì)訓(xùn)練出的回歸樹進(jìn)行剪枝(由于不需要用戶指定,后剪枝是一種更為理想化的剪枝方法)

后剪枝邏輯:對(duì)訓(xùn)練好的回歸樹,自上而下找到葉節(jié)點(diǎn),用測(cè)試集來判斷將這些葉節(jié)點(diǎn)合并是否能降低測(cè)試誤差,若能,則合并。

#判斷是否是一棵樹(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#得到樹所有葉節(jié)點(diǎn)的均值
def getMean(tree):
    #若子樹仍然是樹,則遞歸調(diào)用getMeant直到葉節(jié)點(diǎn)
    if isTree(tree['left']):
        tree['left']=getMean(tree['left'])
    if isTree(tree['right']):
        tree['right']=getMean(tree['right'])
    return (tree['left']+tree['right'])/2.0

"""剪枝函數(shù):對(duì)訓(xùn)練好的回歸樹,自上而下找到葉節(jié)點(diǎn),用測(cè)試集來判斷將這些葉節(jié)點(diǎn)合并是否能降低測(cè)試誤差,若能則合并"""
def prune(tree,testData):
    #若無測(cè)試數(shù)據(jù),則直接返回樹所有葉節(jié)點(diǎn)的均值(塌陷處理)
    if shape(testData)[0]==0:
        return getMean(tree)
    #若存在任意子集是樹,則將測(cè)試集按當(dāng)前樹的最佳切分特征和特征值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是樹,則該子集遞歸調(diào)用剪枝過程(利用剛才切分好的訓(xùn)練集)
    if isTree(tree['left']):
        tree['left']=prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right']=prune(tree['right'],rSet)
    #若當(dāng)前子集都是葉節(jié)點(diǎn),則計(jì)算該二葉節(jié)點(diǎn)合并前后的誤差,決定是否合并
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        errNotMerge=sum(power(lSet[:,-1].T.tolist()[0]-tree['left'],2))+sum(power(rSet[:,-1].T.tolist()[0]-tree['right'],2))
        treeMean=(tree['left']+tree['right'])/2.0
        errMerge=sum(power(testData[:,-1].T.tolist()[0]-treeMean,2))
        if errMerge<errNotMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

測(cè)試一下

#構(gòu)建回歸樹,可以看到,該回歸樹非常臃腫,過擬合
dataMat3=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2.txt')
regTree3=createTree(dataMat3,ops=(1,2))
testData=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2test.txt')
prune(regTree3,testData)

結(jié)果如下,

可以看到,雖然有6個(gè)葉節(jié)點(diǎn)被剪掉了,但仍然有很多葉節(jié)點(diǎn)保留->后剪枝可能不如預(yù)剪枝有效,因此一般為了尋求最佳模型,會(huì)同時(shí)使用兩種剪枝技術(shù)。


04 模型樹實(shí)現(xiàn)

  • 葉節(jié)點(diǎn)是一個(gè)線性回歸模型:當(dāng)前葉子所有樣本的線性回歸模型
  • 誤差衡量:平方誤差類比線性回歸誤差,用線性模型對(duì)數(shù)據(jù)擬合,計(jì)算真實(shí)值與擬合值之差,求差值的平方和
  • 比回歸樹有更好的可解釋性、更高的預(yù)測(cè)準(zhǔn)確度

模型樹構(gòu)建邏輯:通用函數(shù),算法邏輯與createTree()一致,只需改變其中的葉節(jié)點(diǎn)計(jì)算方法leafType()和誤差計(jì)算方法errType()

#葉節(jié)點(diǎn)計(jì)算方法:該葉節(jié)點(diǎn)所有樣本的標(biāo)準(zhǔn)線性回歸模型,算法與linearRegression()一致
def linearSolve(dataset):
    m,n=shape(dataset)
    X=mat(ones((m,n)));Y=mat(ones((m,1)))
    X[:,1:n]=dataset[:,0:n-1] #X第一列為常數(shù)項(xiàng)1
    Y=dataset[:,-1]
    xTx=X.T*X
    if linalg.det(xTx)==0.0:
        raise NameError("矩陣為奇異矩陣,不可逆,嘗試增大ops的第二個(gè)參數(shù)")
    ws=xTx.I*(X.T*Y)
    return ws,X,Y

def modelLeaf(dataset):
    ws,X,Y=linearSolve(dataset)
    return ws

#誤差計(jì)算方法:用線性模型對(duì)數(shù)據(jù)擬合,計(jì)算真實(shí)值與擬合值之差,求差值的平方和
def modelErr(dataset):
    ws,X,Y=linearSolve(dataset)
    yPred=X*ws
    return sum(power(Y-yPred,2))

測(cè)試一下,訓(xùn)練集如圖藍(lán)點(diǎn)所示,訓(xùn)練模型如圖紅線所示,可以看到,模型樹對(duì)數(shù)據(jù)的預(yù)測(cè)更準(zhǔn)確合理。

dataMat4=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\exp2.txt')
modelTree1=createTree(dataMat4,leafType=modelLeaf,errType=modelErr,ops=(1,10))

05 模型樹剪枝

同樣地,模型樹也會(huì)出現(xiàn)過擬合,也需要剪枝,原理與回歸樹剪枝一樣,只需要替換其中的誤差計(jì)算方式,然后微調(diào)一下剪枝代碼,讓每次遞歸時(shí),對(duì)訓(xùn)練數(shù)據(jù)也遞歸切分。

#判斷是否是一棵樹(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#剪枝函數(shù):對(duì)訓(xùn)練好的模型樹,自上而下找到葉節(jié)點(diǎn),用測(cè)試集來判斷將這些葉節(jié)點(diǎn)合并是否能降低測(cè)試誤差,若能則合并
def modelPrune(tree,trainData,testData):
    m,n=shape(testData)
    #若無測(cè)試數(shù)據(jù),則直接返回樹所有葉節(jié)點(diǎn)的均值(塌陷處理)
    if m==0:
        return tree
    #若存在任意子集是樹,則將測(cè)試集按當(dāng)前樹的最佳切分特征和特征值切分(子集剪枝用)
    #同時(shí)將訓(xùn)練集也按當(dāng)前樹的最佳切分特征和特征值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lTrain,rTrain=binSplitDataSet(trainData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是樹,則該子集遞歸調(diào)用剪枝過程(利用剛才切分好的訓(xùn)練集)
    if isTree(tree['left']):
        tree['left']=modelPrune(tree['left'],lTrain,lSet)
    if isTree(tree['right']):
        tree['right']=modelPrune(tree['right'],rTrain,rSet)
        
    #若當(dāng)前子集都是葉節(jié)點(diǎn),則計(jì)算該二葉節(jié)點(diǎn)合并前后的誤差,決定是否合并
    """
    模型樹,兩個(gè)葉節(jié)點(diǎn)合并前的誤差=((左葉子真實(shí)值-擬合值)的平方和+(右葉子真實(shí)值-擬合值)的平方和)
    模型樹,兩個(gè)葉節(jié)點(diǎn)合并后的誤差=(左右真實(shí)值-左右擬合值)的平方和
    難點(diǎn)在于如何求左右擬合值,即求上層節(jié)點(diǎn)的回歸系數(shù)wsMerge:用上層節(jié)點(diǎn)的traindata,通過linearSolve(traindata)求得
    上層節(jié)點(diǎn)的traindata在lTrain,rTrain的遞歸中已經(jīng)求好了
    """
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lSetX=mat(ones((shape(lSet)[0],n)));rSetX=mat(ones((shape(rSet)[0],n)))
        lSetX[:,1:n]=lSet[:,0:n-1];rSetX[:,1:n]=rSet[:,0:n-1]
        errNotMerge=sum(power(array(lSet[:,-1].T.tolist()[0])-lSetX*tree['left'],2))+sum(power(array(rSet[:,-1].T.tolist()[0])-rSetX*tree['right'],2))
        #難點(diǎn)在于求上層節(jié)點(diǎn)的回歸系數(shù)wsMerge:用上層節(jié)點(diǎn)的traindata,通過linearSolve(traindata)求得
        wsMerge=modelLeaf(trainData)  
        testDataX=mat(ones((m,n)));testDataX[:,1:n]=testData[:,0:n-1]
        errMerge=sum(power(array(testData[:,-1].T.tolist()[0])-testDataX*wsMerge,2))
        if errMerge<errNotMerge:
            print("merging")
            return wsMerge
        else:
            return tree
    else:
        return tree

測(cè)試一下,


06 模型預(yù)測(cè)效果對(duì)比

本次我們構(gòu)建了回歸樹和模型樹,順便構(gòu)建了一個(gè)線性回歸函數(shù),我們先來看看對(duì)于同一組數(shù)據(jù),這三個(gè)模型的預(yù)測(cè)效果吧,這里使用R2值來評(píng)估預(yù)測(cè)效果。

結(jié)果如下,

  • 可以看到,這此數(shù)據(jù)集上,模型樹表現(xiàn)比回歸樹好,線性回歸表現(xiàn)最差
  • 說明樹回歸相比于線性回歸,可以更好地處理復(fù)雜、非線性的數(shù)據(jù)集

07 總結(jié)

至此,我們基本上學(xué)習(xí)了回歸任務(wù)中80%以上的算法模型,主要分為線性回歸模型和樹回歸模型(再回憶一下,KNN也可以用于回歸),它們各有優(yōu)劣,針對(duì)具有不同特點(diǎn)的數(shù)據(jù)集,要選擇合適的算法。


08 參考

  • 《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》 Peter Harrington Chapter9
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容