簡介
在桌面PC或是服務器上使用TensorFlow訓練出來的模型文件,不能直接用在TFLite上運行,需要使用離線工具先轉(zhuǎn)成.tflite文件。筆者發(fā)現(xiàn)官方文檔中很多細節(jié)介紹的都不太明確,在使用過程中需要不斷嘗試。我把自己的嘗試過的步驟分享出來,希望能幫助大家節(jié)省時間。
具體說來,tflite文件的生成大致分為3步:
1. 在算法訓練的腳本中保存圖模型文件(GraphDef)和變量文件(CheckPoint)。
2. 利用freeze_graph工具生成frozen的graphdef文件。
3. 利用toco(Tensorflow Optimizing COnverter)工具,生成最終的tflite文件。

第1步:導出圖模型文件和變量文件
在你的算法的訓練或推理任務的腳本中,利用tensorflow.train中的write_graph和saver API來導出GraphDef及Checkpoint文件。

其中,tf.train.write_graph一行將導出模型的GraphDef文件,實際上保存了訓練的神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)圖信息。存儲格式為protobuffer,所以文件名后綴為pb。

tf.train.saver.save一行導出的是模型的變量文件,實際上保存了整個圖中所有變量目前的取值。

如圖4所示,實際上產(chǎn)生了4個文件。在后續(xù)步驟中需要用到的是nsfw_model.ckpt.data-00000-of-00001這個文件,保存了當前神經(jīng)網(wǎng)絡(luò)各參數(shù)的取值。
第2步:生成frozen的graphdef文件
在此步驟中,使用Tensorflow源代碼中自帶的freeze_graph工具,生成一個frozen的GraphDef文件。
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/data/deep_learning/nsfw/model/nsfw-graph.pb --input_checkpoint=/data/deep_learning/nsfw/model/nsfw_model.ckpt --input_binary=true --output_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb --output_node_names=predictions
這里有兩個地方容易搞錯。第一個地方,input_checkpoint參數(shù)實際上用到的文件應該是nsfw_model.ckpt.data-00000-of-00001,但是在指定文件名的時候只需要指定nsfw_model.ckpt即可。第二個地方,是output_node_names參數(shù),此處指定的是神經(jīng)網(wǎng)絡(luò)圖中的輸出節(jié)點的名字,是在訓練階段的Python腳本中定義的。如下圖所示,在定義網(wǎng)絡(luò)結(jié)構(gòu)時,輸出節(jié)點的名稱為"predictions"。則最終output_node_names需要指定為“predictions”。

當然,也可以利用summarize_graph打印出模型的輸入和輸出節(jié)點,如:
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb


第3步:生成最終的tflite文件
在此步驟中,使用Tensorflow源代碼中自帶的toco工具,生成一個可供TensorFlow Lite框架使用tflite文件。其中input_arrays和output_arrays的名稱需要與定義網(wǎng)絡(luò)類型時取的名稱保持一致。
bazel run --config=opt tensorflow/contrib/lite/toco:toco --input_file=/data/deep_learning/nsfw/model/frozen_nsfw.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=/data/deep_learning/nsfw/model/nsfw.lite --inference_type=FLOAT --input_type=FLOAT --input_arrays=input --output_arrays=predictions --input_shapes=1,224,224,3
生成的nsfw.lite文件即可用于TensorFlow Lite應用。