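Li Hang's Statistical Learning Methods (3): The k-Nearest Neighbor Algorithm

The k-nearest neighbor algorithm

Given a training dataset and a new input instance, find the k training instances closest to the input and decide the input's class from the classes of those k instances (usually by majority vote).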
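
To make the description above concrete, here is a minimal brute-force sketch. The function name `knn_classify`, the toy data and the use of Euclidean distance are assumptions chosen for illustration; they do not come from the book.

```python
from collections import Counter


def euclidean(a, b):
    # Straight-line distance between two points of equal dimension
    return sum((u - v) ** 2 for u, v in zip(a, b)) ** 0.5


def knn_classify(train, query, k=3):
    # train is a list of (point, label) pairs.
    # Keep the k training instances closest to the query.
    neighbors = sorted(train, key=lambda pl: euclidean(pl[0], query))[:k]
    # Majority vote over the labels of those k neighbors
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0]


train = [
    ((1.0, 1.0), "A"), ((1.5, 2.0), "A"), ((2.0, 1.0), "A"),
    ((6.0, 6.0), "B"), ((7.0, 5.5), "B"), ((6.5, 7.0), "B"),
]
print(knn_classify(train, (2.0, 2.0), k=3))  # A
```

Sorting the whole training set on every query is exactly the linear scan that becomes too slow on large datasets, which is what motivates the kd tree later in this post.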


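The k-nearest neighbor model

The model has three elements: the distance metric, the choice of k, and the classification decision rule.

Model

For a given training set, once the three elements are fixed, the class of any instance (training or input) is uniquely determined. This amounts to partitioning the feature space into subspaces, each of which is assigned a class.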



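Distance metric
For the n-dimensional real vector space $\mathbf{R}^n$, the $L_p$ distance (also known as the Minkowski distance) is commonly used.

For $x_i = (x_i^{(1)}, \dots, x_i^{(n)})$ and $x_j = (x_j^{(1)}, \dots, x_j^{(n)})$, the $L_p$ distance is defined as:

$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{1/p}, \quad p \ge 1$$

When p = 2, it is the Euclidean distance:

$$L_2(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^2 \right)^{1/2}$$

When p = 1, it is the Manhattan distance:

$$L_1(x_i, x_j) = \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|$$

When p = ∞, it is the maximum of the per-coordinate distances:

$$L_\infty(x_i, x_j) = \max_{l} \left| x_i^{(l)} - x_j^{(l)} \right|$$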

[Figure: illustration of the Lp distances above]
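
The three special cases of the Lp distance can be checked numerically. The sketch below is only an illustration; the helper name `lp_distance` and the sample points are made up for this example.

```python
def lp_distance(x, y, p=2):
    # Per-coordinate absolute differences
    diffs = [abs(a - b) for a, b in zip(x, y)]
    if p == float("inf"):
        # p = infinity: the largest coordinate difference
        return max(diffs)
    # General case: (sum of |difference|^p)^(1/p)
    return sum(d ** p for d in diffs) ** (1.0 / p)


x1, x2 = (1, 1), (4, 5)
print(lp_distance(x1, x2, 1))             # Manhattan: 3 + 4 = 7.0
print(lp_distance(x1, x2, 2))             # Euclidean: sqrt(9 + 16) = 5.0
print(lp_distance(x1, x2, float("inf")))  # max(3, 4) = 4
```

Because the distance changes with p, the nearest neighbor of a point can also change with p.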

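Choosing k

With a small k, the prediction depends on only a few nearby instances, so it is easily affected by noise and tends to overfit.

With a large k, training instances far from the input also influence the prediction, which can lead to errors.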
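
The effect of k can be seen on a tiny made-up dataset: a cluster of class A that contains one noisy point labelled B, plus a distant cluster of class B. The data, the query point and the helper `knn_classify` are assumptions for illustration only.

```python
from collections import Counter


def knn_classify(train, query, k):
    # Brute-force kNN with Euclidean distance and majority vote
    def dist(p):
        return sum((a - b) ** 2 for a, b in zip(p, query)) ** 0.5
    neighbors = sorted(train, key=lambda pl: dist(pl[0]))[:k]
    return Counter(label for _, label in neighbors).most_common(1)[0][0]


train = [
    ((1, 1), "A"), ((1, 2), "A"), ((2, 1), "A"), ((2, 2), "A"),
    ((1.5, 1.5), "B"),  # a noisy point inside the A cluster
    ((8, 8), "B"), ((8, 9), "B"), ((9, 8), "B"), ((9, 9), "B"), ((9, 10), "B"),
]
query = (1.6, 1.6)  # lies in the middle of the A cluster

for k in (1, 3, len(train)):
    print(k, knn_classify(train, query, k))
# 1 B   (the single noisy neighbor decides)
# 3 A   (the noise is outvoted)
# 10 B  (every instance votes, so the globally larger class wins)
```

With k equal to the size of the training set, the prediction no longer depends on the query at all.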

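Classification decision rule

Under the 0-1 loss function, if the neighborhood $N_k(x)$ is assigned class $c_j$, the misclassification rate is:

$$\frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i \ne c_j) = 1 - \frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i = c_j)$$

Here $N_k(x)$ is the set of the k nearest neighbors of x. To minimise the left-hand side, the sum on the right

$$\sum_{x_i \in N_k(x)} I(y_i = c_j)$$

must be as large as possible, so majority voting is equivalent to empirical risk minimization.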
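
The equivalence can be checked on a small example: for a fixed set of neighbor labels, the class with the lowest misclassification rate is the majority class. The labels and the function name `misclassification_rate` are hypothetical.

```python
from collections import Counter


def misclassification_rate(neighbor_labels, c):
    # Fraction of neighbors whose label differs from the assigned class c
    wrong = sum(1 for y in neighbor_labels if y != c)
    return wrong / len(neighbor_labels)


labels = ["A", "B", "A", "A", "B"]  # labels of the k = 5 nearest neighbors

rates = {c: misclassification_rate(labels, c) for c in sorted(set(labels))}
print(rates)  # {'A': 0.4, 'B': 0.6}

best_by_risk = min(rates, key=rates.get)
majority = Counter(labels).most_common(1)[0][0]
print(best_by_risk, majority)  # A A
```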

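Implementing k-nearest neighbors: the kd tree

The core of the algorithm is finding the k nearest neighbors quickly. The naive approach is a linear scan over all training instances, which is too slow for large training sets. The method introduced here is the kd tree.

Building a kd tree

Take a subset S of the dataset T, initialised as S = T. Set the current node to node = root and the dimension index to i = 0, then apply the following step to S recursively:

Find the point at the median of S along dimension i, and draw a hyperplane through that point perpendicular to the i-th coordinate axis. Add the point as a child of node. The hyperplane splits the space into two parts. Repeat the step on each part (S = S', ++i, node = current), cycling i back to 0 after the last dimension, until no further split is possible.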


    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


    class node:
        def __init__(self, point):
            self.left = None
            self.right = None
            self.point = point


    def median(lst):
        # Integer division: "/" yields a float in Python 3 and cannot index a list
        m = len(lst) // 2
        return lst[m], m


    def build_kdtree(data, d):
        # Sort along the current split dimension d and split at the median point
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
        tree = node(p)

        del data[m]
        print(data, p)

        # The data is two-dimensional, so alternate between dimensions 0 and 1
        if m > 0: tree.left = build_kdtree(data[:m], 1 - d)
        if len(data) > m: tree.right = build_kdtree(data[m:], 1 - d)
        return tree


    kd_tree = build_kdtree(T, 0)
    print(kd_tree)

Visualization

Visualizing the process takes a bit more work: the intermediate results have to be saved and then displayed appropriately.

    # -*- coding:utf-8 -*-
    # Filename: kdtree.py
    # Author:hankcs
    # Date: 2015/2/4 15:01
    import copy
    import itertools
    from matplotlib import pyplot as plt
    from matplotlib.patches import Rectangle
    from matplotlib import animation
     
    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
     
     
    def draw_point(data):
        X, Y = [], []
        for p in data:
            X.append(p[0])
            Y.append(p[1])
        plt.plot(X, Y, 'bo')
     
     
    def draw_line(xy_list):
        for xy in xy_list:
            x, y = xy
            plt.plot(x, y, 'g', lw=2)
     
     
    def draw_square(square_list):
        currentAxis = plt.gca()
        colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
        for square in square_list:
            currentAxis.add_patch(
                Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                          color=next(colors)))
     
     
    def median(lst):
        m = len(lst) // 2
        return lst[m], m
     
     
    history_quare = []
     
     
    def build_kdtree(data, d, square):
        history_quare.append(square)
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
     
        del data[m]
        print(data, p)
     
        if m >= 0:
            sub_square = copy.deepcopy(square)
            if d == 0:
                sub_square[1][0] = p[0]
            else:
                sub_square[1][1] = p[1]
            history_quare.append(sub_square)
            if m > 0: build_kdtree(data[:m], not d, sub_square)
        if len(data) > 1:
            sub_square = copy.deepcopy(square)
            if d == 0:
                sub_square[0][0] = p[0]
            else:
                sub_square[0][1] = p[1]
            build_kdtree(data[m:], not d, sub_square)
     
     
    build_kdtree(T, 0, [[0, 0], [10, 10]])
    print(history_quare)
     
     
    # draw an animation to show how it works, the data comes from history
    # first set up the figure, the axis, and the plot element we want to animate
    fig = plt.figure()
    ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
    line, = ax.plot([], [], 'g', lw=2)
    label = ax.text(0, 0, '')
     
    # initialization function: plot the background of each frame
    def init():
        plt.axis([0, 10, 0, 10])
        plt.grid(True)
        plt.xlabel('x_1')
        plt.ylabel('x_2')
        plt.title('build kd tree (www.hankcs.com)')
        draw_point(T)
     
     
    currentAxis = plt.gca()
    colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])
     
    # animation function.  this is called sequentially
    def animate(i):
        square = history_quare[i]
        currentAxis.add_patch(
            Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                      color=next(colors)))
        return
     
    # call the animator.  blit=true means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
                                   blit=False)
    # save before showing; the 'imagemagick' writer requires ImageMagick to be installed
    anim.save('kdtree_build.gif', fps=2, writer='imagemagick')
    plt.show()

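Searching a kd tree

The code above only builds the kd tree; it does not search it. Let's implement search now.

Searching works the same way as in a binary tree and is a recursive process. First descend to the position where the target point would be inserted. Then walk back up: draw a hypersphere centred on the target whose radius is the distance to the current nearest point, and use any points inside it to update the nearest neighbor (or the k nearest neighbors). Taking the single nearest neighbor as an example, an implementation follows. Because the test data is simple, it does not implement the logic for testing whether the hypersphere intersects a hyperrectangle.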

    # -*- coding:utf-8 -*-
    # Filename: search_kdtree.py
    # Author:hankcs
    # Date: 2015/2/4 15:01

    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


    class node:
        def __init__(self, point):
            self.left = None
            self.right = None
            self.point = point
            self.parent = None

        def set_left(self, left):
            if left is None: return
            left.parent = self
            self.left = left

        def set_right(self, right):
            if right is None: return
            right.parent = self
            self.right = right


    def median(lst):
        # Integer division so that m can be used as a list index in Python 3
        m = len(lst) // 2
        return lst[m], m


    def build_kdtree(data, d):
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
        tree = node(p)

        del data[m]

        # The data is two-dimensional, so alternate between dimensions 0 and 1
        if m > 0: tree.set_left(build_kdtree(data[:m], 1 - d))
        if len(data) > m: tree.set_right(build_kdtree(data[m:], 1 - d))
        return tree


    def distance(a, b):
        # Euclidean distance; the print shows which pairs get compared
        print(a, b)
        return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5


    def search_kdtree(tree, d, target):
        # Descend to the node where the target would be inserted
        if target[d] < tree.point[d]:
            if tree.left is not None:
                return search_kdtree(tree.left, 1 - d, target)
        else:
            if tree.right is not None:
                return search_kdtree(tree.right, 1 - d, target)

        def update_best(t, best):
            if t is None: return
            t = t.point
            dist = distance(t, target)
            if dist < best[1]:
                best[1] = dist
                best[0] = t

        # Walk back up, comparing against both children of every ancestor.
        # Simplification: the root itself is never compared, and no
        # hypersphere/hyperrectangle intersection test is performed.
        best = [tree.point, 100000.0]
        while tree.parent is not None:
            update_best(tree.parent.left, best)
            update_best(tree.parent.right, best)
            tree = tree.parent
        return best[0]


    kd_tree = build_kdtree(T, 0)
    print(search_kdtree(kd_tree, 0, [9, 4]))

Output

    [8, 1] [9, 4]
    [5, 4] [9, 4]
    [9, 6] [9, 4]
    [9, 6]
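
The implementation above skips the intersection test. As a rough sketch of what that step could look like, the version below checks every visited node (including the root) and only descends into the far side of a split when the hypersphere around the target crosses the splitting hyperplane. It uses the simpler hyperplane-distance check rather than a full hyperrectangle test. The names `Node`, `build` and `nearest` are assumptions for this illustration and are not part of the original code.

```python
class Node:
    def __init__(self, point, axis, left=None, right=None):
        self.point = point  # the splitting point stored at this node
        self.axis = axis    # the dimension this node splits on
        self.left = left
        self.right = right


def build(points, depth=0):
    if not points:
        return None
    axis = depth % len(points[0])  # cycle through the dimensions
    points = sorted(points, key=lambda p: p[axis])
    m = len(points) // 2           # median index
    return Node(points[m], axis,
                build(points[:m], depth + 1),
                build(points[m + 1:], depth + 1))


def euclidean(a, b):
    return sum((u - v) ** 2 for u, v in zip(a, b)) ** 0.5


def nearest(node, target, best=None):
    # best is a (point, distance) pair, or None before any node is visited
    if node is None:
        return best
    d = euclidean(node.point, target)
    if best is None or d < best[1]:
        best = (node.point, d)
    # Signed distance from the target to this node's splitting hyperplane
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, best)
    # Search the far side only if the hypersphere crosses the hyperplane
    if abs(diff) < best[1]:
        best = nearest(far, target, best)
    return best


T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = build(T)
print(nearest(tree, [9, 4]))  # ([9, 6], 2.0)
```

For the query [9, 4] this agrees with the result above. Unlike the simplified version, it should also stay correct when the true nearest neighbor lies in a subtree on the other side of a split.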