KNN分類算法,是理論上比較成熟的方法,也是最簡單的機(jī)器學(xué)習(xí)算法之一。
該方法的思路是:如果一個樣本在特征空間中的k個最相似(即特征空間中最鄰近)的樣本中的大多數(shù)屬于某一個類別,則該樣本也屬于這個類別。
KNN算法中,所選擇的鄰居都是已經(jīng)正確分類的對象。該方法在定類決策上只依據(jù)最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。
一個對于KNN算法解釋最清楚的圖如下所示:

藍(lán)方塊和紅三角均是已有分類數(shù)據(jù),當(dāng)前的任務(wù)是將綠色圓塊進(jìn)行分類判斷,判斷是屬于藍(lán)方塊或者紅三角。
當(dāng)然這里的分類還跟K值是有關(guān)的:
如果K=3(實線圈),紅三角占比2/3,則判斷為紅三角;
如果K=5(虛線圈),藍(lán)方塊占比3/5,則判斷為藍(lán)方塊。
由此可以看出knn算法實際上根本就不用進(jìn)行訓(xùn)練,而是直接進(jìn)行計算的,訓(xùn)練時間為0,計算時間為訓(xùn)練集規(guī)模n。
knn算法的基本要素大致有3個:
1、K 值的選擇
2、距離的度量
3、分類決策規(guī)則
使用方式
1、K 值會對算法的結(jié)果產(chǎn)生重大影響。K值較小意味著只有與輸入實例較近的訓(xùn)練實例才會對預(yù)測結(jié)果起作用,容易發(fā)生過擬合;如果 K 值較大,優(yōu)點是可以減少學(xué)習(xí)的估計誤差,缺點是學(xué)習(xí)的近似誤差增大,這時與輸入實例較遠(yuǎn)的訓(xùn)練實例也會對預(yù)測起作用,是預(yù)測發(fā)生錯誤。在實際應(yīng)用中,K 值一般選擇一個較小的數(shù)值,通常采用交叉驗證的方法來選擇最有的 K 值。隨著訓(xùn)練實例數(shù)目趨向于無窮和 K=1 時,誤差率不會超過貝葉斯誤差率的2倍,如果K也趨向于無窮,則誤差率趨向于貝葉斯誤差率。
2、算法中的分類決策規(guī)則往往是多數(shù)表決,即由輸入實例的 K 個最臨近的訓(xùn)練實例中的多數(shù)類決定輸入實例的類別
3、距離度量一般采用 Lp 距離,當(dāng)p=2時,即為歐氏距離,在度量之前,應(yīng)該將每個屬性的值規(guī)范化,這樣有助于防止具有較大初始值域的屬性比具有較小初始值域的屬性的權(quán)重過大。
knn算法在分類時主要的不足是,當(dāng)樣本不平衡時,如果一個類的樣本容量很大,而其他類樣本容量很小時,有可能導(dǎo)致當(dāng)輸入一個新樣本時,該樣本的 K 個鄰居中大容量類的樣本占多數(shù)。
KNN算法python2代碼實現(xiàn)
from numpy import *
import operator
from os import listdir
import matplotlib
import matplotlib.pyplot as plt
def classify0(inX, dataSet, labels, k):#對每組數(shù)據(jù)進(jìn)行分類,inX為用于分類的輸入向量,
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet#歐式距離計算tile函數(shù)位于python模塊 numpy.lib.shape_base中,他的功能是重復(fù)某個數(shù)組。比如tile(A,n),功能是將數(shù)組A重復(fù)n次,構(gòu)成一個新的數(shù)組
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#,將classCount字典分解為元祖列表,排序k個距離值
return sortedClassCount[0][0]#返回發(fā)生頻率最高的元素標(biāo)簽
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
def file2matrix(filename):#準(zhǔn)備數(shù)據(jù),文本文件解析數(shù)據(jù),處理輸入格式問題
fr = open(filename)
numberOfLines = len(fr.readlines()) ?#得到文件行數(shù) ? get the number of lines in the file
returnMat = zeros((numberOfLines,3)) ? ?#prepare matrix to return創(chuàng)建返回的numpy矩陣,這里的3 可以按照實際需求去改
classLabelVector = [] ? ? ? ? ? ?#prepare labels return
fr = open(filename)
index = 0
for line in fr.readlines():#解析文件數(shù)據(jù)到列表
line = line.strip()#截取掉所有的回車字符
#print line
listFromLine = line.split('\t')#分割成元素列表
#print listFromLine
returnMat[index,:] = listFromLine[0:3]#選取前3個元素存儲到特征矩陣中
#print returnMat[index,:]
classLabelVector.append(int(listFromLine[-1]))#將列表中的最后一列存儲到向量classLabelVector中,這里必須是int,否則會默認(rèn)字符串
index += 1
return returnMat,classLabelVector#輸出為訓(xùn)練樣本矩陣與類標(biāo)簽向量
def autoNorm(dataSet):#該函數(shù)可以自動將數(shù)字特征值轉(zhuǎn)換為0到1的區(qū)間
minVals = dataSet.min(0)#將每列的最小值放在變量minVals中
maxVals = dataSet.max(0)#將每列的最大值放在變量maxVals中
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))#tile()函數(shù)將變量內(nèi)容復(fù)制成輸入矩陣同樣大小的矩陣
normDataSet = normDataSet/tile(ranges, (m,1)) ?#element wise divide#/表示具體特征值相除而不是矩陣相除
return normDataSet, ranges, minVals
def datingClassTest():#計算錯誤率 評估分類器準(zhǔn)確率,該函數(shù)可以在任何時候測試分類器效果
hoRatio = 0.10 ? #hold out 10%
datingDataMat,datingLabels = file2matrix('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/datingTestSet2.txt') ? ?#load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)#計算測試樣本的數(shù)量
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],4)#將訓(xùn)練與測試數(shù)據(jù)輸入到分類器函數(shù)classify0函數(shù)中
#print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is: %f" % (errorCount/float(numTestVecs))#改變hoRatio與k的值,檢驗錯誤率會有所不同。
print errorCount
def img2vector(filename):#在二進(jìn)制存儲的圖像數(shù)據(jù)上使用KNN,將圖像轉(zhuǎn)換為測試向量
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):#循環(huán)讀出文件的前32行
returnVect[0,32*i+j] = int(lineStr[j])#將32*32的二進(jìn)制圖像矩陣轉(zhuǎn)化為1*1024的向量
return returnVect#返回數(shù)組
from os import listdir
def handwritingClassTest():#測試算法,使用knn識別手寫數(shù)字
hwLabels = []
trainingFileList = listdir('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/trainingDigits') ? ? ?#load the training set 列出給定目錄的文件名
m = len(trainingFileList)
trainingMat = zeros((m,1024))#訓(xùn)練矩陣,該矩陣的每行數(shù)據(jù)存儲一個圖像
for i in range(m):#從文件名解析分類數(shù)字
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] ? #take off .txt
classNumStr = int(fileStr.split('_')[0])#
hwLabels.append(classNumStr)#將類編碼存儲在hwLabels向量中
trainingMat[i,:] = img2vector('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/trainingDigits/%s' % fileNameStr)
testFileList = listdir('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/testDigits') ? ?#iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] ? #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount/float(mTest))
if __name__=="__main__":
datingDataMat,datingLabels=file2matrix('C:/Users/HZF/Desktop/machinelearninginaction/Ch02/datingTestSet2.txt')
print datingDataMat
#print datingLabels[:20]
normDataSet, ranges, minVals=autoNorm(datingDataMat)
print normDataSet
errorCount=datingClassTest()
handwritingClassTest()
#print errorCount
fig=plt.figure()#創(chuàng)建散點圖圖形化數(shù)據(jù),以辨識數(shù)據(jù)模式
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))#使用數(shù)據(jù)集的第2,3列數(shù)據(jù),利用變量datingLables存儲的類標(biāo)簽屬性,繪制色彩,尺寸不同的點
plt.show()
KNN算法上篇就寫這么多,歡迎評論!
參考文獻(xiàn)
1、[機(jī)器學(xué)習(xí)] --KNN K-最鄰近算法
2、《機(jī)器學(xué)習(xí)實戰(zhàn)》 (書)