簡介
KNN是很多人接觸機(jī)器學(xué)習(xí)的第一個(gè)算法,我也不例外。在利用OpenCV (C++)結(jié)合KNN處理MNIST數(shù)據(jù),遇到了很多的坑,在這里和各位分享一下心得。
完整代碼在這里,喜歡的可以Star,不喜歡的可以提建議!
環(huán)境是MacOS + OpenCV4
關(guān)鍵步驟概覽
關(guān)鍵步驟的代碼取自于我實(shí)現(xiàn)的部分,這里只是闡述關(guān)鍵步驟和一些心得,詳細(xì)地可以看我代碼,比較容易看懂的!
- 獲得MNIST的訓(xùn)練集(包含圖片和數(shù)據(jù))
bool get_train_images_with_label_from_mnist(cv::Ptr<cv::ml::TrainData> &trainData)
- 獲得MNIST的測試集(包含圖片和數(shù)據(jù))
bool get_test_images_with_label_from_mnist(cv::Mat &testData, cv::Mat &testLabel)
- 創(chuàng)建KNN模型,并設(shè)定一些基本的參數(shù)。
Ptr<ml::KNearest> knn_model = ml::KNearest::create();
knn_model->setDefaultK(K_value); // 指明KNN的K
knn_model->setIsClassifier(true); // 指明這個(gè)KNN是用來分類的
knn_model->setAlgorithmType(cv::ml::KNearest::Types::BRUTE_FORCE);
- 訓(xùn)練剛剛創(chuàng)建的KNN模型
knn_model->train(training_set, 0); // 利用訓(xùn)練集訓(xùn)練KNN
- 用
findNearest進(jìn)行預(yù)測
knn_model->findNearest(test_set, knn_model->getDefaultK(), result_set);
注意: 這里的result_set的結(jié)果返回的是CV_32F的類型,也就是說里面的元素是32位的float,可能會(huì)和我們之后用的標(biāo)記(可能會(huì)用int32_t來存儲(chǔ)),所以需要static_cast。
- 利用測試集的標(biāo)記
testLabel和result_set的比較來計(jì)算預(yù)測準(zhǔn)確率。(如果它們類型不一樣,比如一個(gè)是float32,另一個(gè)是int32,請記得cast)
如何處理MNIST數(shù)據(jù)集
這里給出3個(gè)關(guān)鍵的提示
MNIST 數(shù)據(jù)集是用大端的方式存儲(chǔ)的,用Intel處理器的PC機(jī)一般是小端存儲(chǔ)的,需要做轉(zhuǎn)換。
cv::ml::TrainData::create()只能處理CV_32F類型的,也就是32位float, 但是NMIST中的像素是用unsigned byte存的。MNIST中的圖片是二維的,但你需要把它轉(zhuǎn)存成一維的數(shù)組以便于它被
cv::Mat處理。
完整代碼
再說一遍,完整代碼位置:
https://github.com/VinStarry/CV_codes/tree/master/elementary/knn
測試結(jié)果
準(zhǔn)確率與K的取值散點(diǎn)圖

錯(cuò)誤結(jié)果示例
-
預(yù)測結(jié)果:9, 實(shí)際數(shù)字:4
knn_wrongcase_1.jpg -
預(yù)測結(jié)果:6,實(shí)際數(shù)字:4
knn_wrongcase_2.jpg -
預(yù)測結(jié)果:8,實(shí)際數(shù)字:9
knn_wrongcase_3.jpg


