這一節(jié)我們主要講述如何使用預(yù)訓練模型。Ipython notebook鏈接在這里。
模型下載
你可以去Model Zoo下載預(yù)訓練好的模型,或者使用Caffe2的models.download模塊獲取預(yù)訓練的模型。caffe2.python.models.download需要模型的名字所謂參數(shù)。你可以去看看有什么模型可用,然后替換下面代碼中的squeezenet。
python -m caffe2.python.models.download -i squeezenet
譯者注:如果不明白為什么用python -m 執(zhí)行,可以看看這個帖子。
如果上面下載成功,那么你應(yīng)該下載了 squeezenet到你的文件夾中。如果你使用i那么模型文件將下載到/caffe2/python/models文件夾中。當然,你也可以下載所有模型文件:git clone https://github.com/caffe2/models。
Overview
在這個教程中,我們將會使用squeezenet模型進行圖片的目標識別。如果,你讀了前面的預(yù)處理章節(jié),那么你會看到我們使用rescale和crop對圖像進行處理。同時做了CHW和BGR的轉(zhuǎn)換,最后的圖像數(shù)據(jù)是NCHW。我們也統(tǒng)計了圖像均值,而不是簡單地將圖像減去128.
你會發(fā)現(xiàn)載入預(yù)處理模型是相當簡單的,僅僅需要幾行代碼就可以了。
- 讀取protobuf文件
with open("init_net.pb") as f:
init_net = f.read()
with open("predict_net.pb") as f:
predict_net = f.read()
- 使用
Predictor函數(shù)從protobuf中載入blobs數(shù)據(jù)
p = workspace.Predictor(init_net, predict_net)
- 跑網(wǎng)絡(luò)并獲取結(jié)果
results = p.run([img])
返回的結(jié)果是一個多維概率的矩陣,每一行是一個百分比,表示網(wǎng)絡(luò)識別出圖像屬于某一個物體的概率。當你使用前面那張花圖來測試時,網(wǎng)絡(luò)的返回應(yīng)該告訴你超過95的概率是雛菊。
Configuration
網(wǎng)絡(luò)設(shè)置如下:
# 你安裝caffe2的路徑
CAFFE2_ROOT = "~/caffe2"
# 假設(shè)是caffe2的子目錄
CAFFE_MODELS = "~/caffe2/caffe2/python/models"
#如果你有mean file,把它放在模型文件那個目錄里面
%matplotlib inline
from caffe2.proto import caffe2_pb2
import numpy as np
import skimage.io
import skimage.transform
from matplotlib import pyplot
import os
from caffe2.python import core, workspace
import urllib2
print("Required modules imported.")
傳遞圖像的路徑,或者網(wǎng)絡(luò)圖像的URL。物體編碼參照Alex Net,比如“985”代表是“雛菊”。其他編碼參照這里。
IMAGE_LOCATION = "https://cdn.pixabay.com/photo/2015/02/10/21/28/flower-631765_1280.jpg"
# 參數(shù)格式: folder, INIT_NET, predict_net, mean , input image size
MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.npy', 227
# AlexNet的物體編碼
codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
print "Config set!"
處理圖像
def crop_center(img,cropx,cropy):
y,x,c = img.shape
startx = x//2-(cropx//2)
starty = y//2-(cropy//2)
return img[starty:starty+cropy,startx:startx+cropx]
def rescale(img, input_height, input_width):
print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!")
print("Model's input shape is %dx%d") % (input_height, input_width)
aspect = img.shape[1]/float(img.shape[0])
print("Orginal aspect ratio: " + str(aspect))
if(aspect>1):
# landscape orientation - wide image
res = int(aspect * input_height)
imgScaled = skimage.transform.resize(img, (input_width, res))
if(aspect<1):
# portrait orientation - tall image
res = int(input_width/aspect)
imgScaled = skimage.transform.resize(img, (res, input_height))
if(aspect == 1):
imgScaled = skimage.transform.resize(img, (input_width, input_height))
pyplot.figure()
pyplot.imshow(imgScaled)
pyplot.axis('on')
pyplot.title('Rescaled image')
print("New image shape:" + str(imgScaled.shape) + " in HWC")
return imgScaled
print "Functions set."
# set paths and variables from model choice and prep image
CAFFE2_ROOT = os.path.expanduser(CAFFE2_ROOT)
CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
# 均值最好從訓練集中計算得到
MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3])
if not os.path.exists(MEAN_FILE):
mean = 128
else:
mean = np.load(MEAN_FILE).mean(1).mean(1)
mean = mean[:, np.newaxis, np.newaxis]
print "mean was set to: ", mean
# 輸入大小
INPUT_IMAGE_SIZE = MODEL[4]
# 確保所有文件存在
if not os.path.exists(CAFFE2_ROOT):
print("Houston, you may have a problem.")
INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])
print 'INIT_NET = ', INIT_NET
PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])
print 'PREDICT_NET = ', PREDICT_NET
if not os.path.exists(INIT_NET):
print(INIT_NET + " not found!")
else:
print "Found ", INIT_NET, "...Now looking for", PREDICT_NET
if not os.path.exists(PREDICT_NET):
print "Caffe model file, " + PREDICT_NET + " was not found!"
else:
print "All needed files found! Loading the model in the next block."
#載入一張圖像
img = skimage.img_as_float(skimage.io.imread(IMAGE_LOCATION)).astype(np.float32)
img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
print "After crop: " , img.shape
pyplot.figure()
pyplot.imshow(img)
pyplot.axis('on')
pyplot.title('Cropped')
# 轉(zhuǎn)換為CHW
img = img.swapaxes(1, 2).swapaxes(0, 1)
pyplot.figure()
for i in range(3):
pyplot.subplot(1, 3, i+1)
pyplot.imshow(img[i])
pyplot.axis('off')
pyplot.title('RGB channel %d' % (i+1))
#轉(zhuǎn)換為BGR
img = img[(2, 1, 0), :, :]
# 減均值
img = img * 255 - mean
# 增加batch size
img = img[np.newaxis, :, :, :].astype(np.float32)
print "NCHW: ", img.shape
狀態(tài)輸出:
Functions set.
mean was set to: 128
INIT_NET = /home/aaron/models/squeezenet/init_net.pb
PREDICT_NET = /home/aaron/models/squeezenet/predict_net.pb
Found /home/aaron/models/squeezenet/init_net.pb ...Now looking for /home/aaron/models/squeezenet/predict_net.pb
All needed files found! Loading the model in the next block.
Original image shape:(751, 1280, 3) and remember it should be in H, W, C!
Model's input shape is 227x227
Orginal aspect ratio: 1.70439414115
New image shape:(227, 386, 3) in HWC
After crop: (227, 227, 3)
NCHW: (1, 3, 227, 227)



既然圖像準備好了,那么放進CNN里面吧。打開protobuf,載入到workspace中,并跑起網(wǎng)絡(luò)。
#初始化網(wǎng)絡(luò)
with open(INIT_NET) as f:
init_net = f.read()
with open(PREDICT_NET) as f:
predict_net = f.read()
p = workspace.Predictor(init_net, predict_net)
# 進行預(yù)測
results = p.run([img])
# 把結(jié)果轉(zhuǎn)換為np矩陣
results = np.asarray(results)
print "results shape: ", results.shape
results shape: (1, 1, 1000, 1, 1)
看到1000沒。如果我們batch很大,那么這個矩陣將會很大,但是中間的維度仍然是1000。它記錄著模型預(yù)測的每一個類別的概率?,F(xiàn)在,讓我們繼續(xù)下一步。
results = np.delete(results, 1)#這句話不是很明白
index = 0
highest = 0
arr = np.empty((0,2), dtype=object)#創(chuàng)建一個0x2的矩陣?
arr[:,0] = int(10)#這是什么個意思?
arr[:,1:] = float(10)
for i, r in enumerate(results):
# imagenet的索引從1開始
i=i+1
arr = np.append(arr, np.array([[i,r]]), axis=0)
if (r > highest):
highest = r
index = i
print index, " :: ", highest
# top 3 結(jié)果
# sorted(arr, key=lambda x: x[1], reverse=True)[:3]
# 獲取 code list
response = urllib2.urlopen(codes)
for line in response:
code, result = line.partition(":")[::2]
if (code.strip() == str(index)):
print result.strip()[1:-2]
最后輸出:
985 :: 0.979059
daisy
譯者注:上面最后一段處理結(jié)果的代碼,譯者也不是很明白,有木有明白的同學在下面回復(fù)下?
轉(zhuǎn)載請注明出處:http://m.itdecent.cn/c/cf07b31bb5f2