
姓名:劉強
【嵌牛導讀】
手寫識別是計算機視覺的一個研究方向,可以看成是一個分類問題。機器學習的任務,便是解決分類(有監(jiān)督學習)、聚類(無監(jiān)督學習)和回歸(強化學習)問題。k-近鄰算法(簡稱kNN)是最簡單的有監(jiān)督學習算法,本文介紹了如何用k-近鄰算法構建一個手寫識別系統(tǒng),并附上其python實現。
【嵌牛鼻子】
k-近鄰算法 機器學習 分類 手寫識別
【嵌牛提問】
k-近鄰算法是什么? 如何構建一個手寫識別系統(tǒng)?
【嵌牛正文】
k近鄰算法基本思想
存在一個樣本數據集,稱為訓練集,訓練集中每個數據都存在標簽(標簽即數據所屬的類別,從這一點可以看出,k近鄰算法屬于有監(jiān)督學習)。對于不知道標簽的新數據,將新數據的每個特征與訓練集中數據對應的特征相比較,選出訓練集中前k個最相似的數據(這就是k-近鄰算法名稱中k的出處),然后對這k個數據做統(tǒng)計,選擇出現次數最多的標簽作為新數據的標簽(即k-近鄰算法的輸出)。
從其基本思想可以看出,k-近鄰算法用于解決分類問題。所謂近鄰,其實是用數據之間的歐氏距離來衡量它們的相似程度,距離越短,表示兩個數據越相似。

構建手寫識別系統(tǒng)
需求分析
很多輸入法都支持手寫輸入,實現手寫輸入通常的做法是把手寫的結果生成圖片,進行圖像識別。我們知道,圖片可以用矩陣表示,對于單通道的灰度圖像,假如分辨率為32X32,則可以用一個32X32的矩陣表示,矩陣中的每個元素表示圖片中該位置的像素,元素的值為0~255之間的灰度值。
而對于手寫圖片,表示方法則更加簡單,因為手寫圖片是只有黑白兩色的二值圖像,利用圖像處理軟件,黑色的位置寫1,白色背景寫0,將其轉成文本文件,如下圖所示:

雖然這樣表示不能有效利用內存空間(本來0/1只需占據1bit的空間,但是變成字符“0”,“1”之后需要用char類型所占的字節(jié)數),但是對于圖像到矩陣的轉換這一過程非常直觀,方便演示。
我們的目標是:將這樣的一幅“圖像”輸入我們的系統(tǒng),我們能夠輸出“圖像”中所顯示的數字(只做數字0~9的識別)。
系統(tǒng)組成
我們的手寫識別系統(tǒng)由以下部分組成:
- 已知標簽的訓練集
- 文件輸入輸出模塊
- kNN算法模塊
已知標簽的訓練集
點此下載:用到的數據及源代碼
其中,trainingDigits文件夾中存放的是用作訓練集的的圖片,其中包含了1934個訓練樣本,testDigits文件夾中存放的是用作測試集的圖片,其中包含了946個測試樣本。每個文件的文件名中含有它的標簽。
文件輸入輸出模塊
python讀文本文件相當簡單,為了迎合后續(xù)的kNN算法,我們不把圖像表示成32X32的矩陣形式,而是將其轉化成1X1024的向量,為此我們定義一個img2vector函數:
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
kNN算法模塊
根據上述對kNN算法的描述,kNN算法有如下步驟:
- 測試數據與訓練集中的每個數據進行比較,以這兩個數據間的歐氏距離作為測試數據和訓練數據間的相似性度量
- 將算出的歐式距離列表從小到大排序,取前k名所對應的訓練集中的數據
- 取出這k個數據的標簽,對數目進行統(tǒng)計,出現次數最多的標簽作為算法的輸出,即分類的結果
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
系統(tǒng)整體代碼
'''
kNN: k Nearest Neighbors
Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number)
Output: the most popular class label
'''
from numpy import *
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') #load the training set
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits') #iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
系統(tǒng)測試
測試環(huán)境
- win10 64位
- python3.6.2
測試步驟
- 打開cmd,進入kNN.py所在的文件夾
- 輸入python進入python shell
- 輸入from kNN import *導入kNN模塊中所有函數
- 輸入handwritingClassTest(),回車
測試結果

從測試結果來看,1.0571%的錯誤率,準確度還是蠻高的……

增加訓練集的樣本容量能有效提高系統(tǒng)的準確度,但是同時增加了運算量,使計算耗時增加。