TensorFlow Lite學習筆記2:生成TFLite模型文件

簡介

在桌面PC或是服務器上使用TensorFlow訓練出來的模型文件,不能直接用在TFLite上運行,需要使用離線工具先轉(zhuǎn)成.tflite文件。筆者發(fā)現(xiàn)官方文檔中很多細節(jié)介紹的都不太明確,在使用過程中需要不斷嘗試。我把自己的嘗試過的步驟分享出來,希望能幫助大家節(jié)省時間。

具體說來,tflite文件的生成大致分為3步:

1. 在算法訓練的腳本中保存圖模型文件(GraphDef)和變量文件(CheckPoint)。

2. 利用freeze_graph工具生成frozen的graphdef文件。

3. 利用toco(Tensorflow Optimizing COnverter)工具,生成最終的tflite文件。

圖1. 生成tflite文件的整個流程示意圖

第1步:導出圖模型文件和變量文件

在你的算法的訓練或推理任務的腳本中,利用tensorflow.train中的write_graph和saver API來導出GraphDef及Checkpoint文件。

圖2. TensorFlow中導出GraphDef文件和Checkpoint文件

其中,tf.train.write_graph一行將導出模型的GraphDef文件,實際上保存了訓練的神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)圖信息。存儲格式為protobuffer,所以文件名后綴為pb。

圖3. 導出的GraphDef文件

tf.train.saver.save一行導出的是模型的變量文件,實際上保存了整個圖中所有變量目前的取值。

圖4. 導出的checkpoint文件

如圖4所示,實際上產(chǎn)生了4個文件。在后續(xù)步驟中需要用到的是nsfw_model.ckpt.data-00000-of-00001這個文件,保存了當前神經(jīng)網(wǎng)絡(luò)各參數(shù)的取值。

第2步:生成frozen的graphdef文件

在此步驟中,使用Tensorflow源代碼中自帶的freeze_graph工具,生成一個frozen的GraphDef文件。

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/data/deep_learning/nsfw/model/nsfw-graph.pb --input_checkpoint=/data/deep_learning/nsfw/model/nsfw_model.ckpt --input_binary=true --output_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb --output_node_names=predictions

這里有兩個地方容易搞錯。第一個地方,input_checkpoint參數(shù)實際上用到的文件應該是nsfw_model.ckpt.data-00000-of-00001,但是在指定文件名的時候只需要指定nsfw_model.ckpt即可。第二個地方,是output_node_names參數(shù),此處指定的是神經(jīng)網(wǎng)絡(luò)圖中的輸出節(jié)點的名字,是在訓練階段的Python腳本中定義的。如下圖所示,在定義網(wǎng)絡(luò)結(jié)構(gòu)時,輸出節(jié)點的名稱為"predictions"。則最終output_node_names需要指定為“predictions”。

圖5. output_node_names參數(shù)取值與網(wǎng)絡(luò)模型定義時的名字要對應

當然,也可以利用summarize_graph打印出模型的輸入和輸出節(jié)點,如:

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb

圖6. 輸入節(jié)點名稱為input
圖7. 輸出節(jié)點名稱為predictions

第3步:生成最終的tflite文件

在此步驟中,使用Tensorflow源代碼中自帶的toco工具,生成一個可供TensorFlow Lite框架使用tflite文件。其中input_arrays和output_arrays的名稱需要與定義網(wǎng)絡(luò)類型時取的名稱保持一致。

bazel run --config=opt tensorflow/contrib/lite/toco:toco --input_file=/data/deep_learning/nsfw/model/frozen_nsfw.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=/data/deep_learning/nsfw/model/nsfw.lite --inference_type=FLOAT --input_type=FLOAT --input_arrays=input --output_arrays=predictions --input_shapes=1,224,224,3

生成的nsfw.lite文件即可用于TensorFlow Lite應用。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關(guān)閱讀更多精彩內(nèi)容

  • 前言 本文主要參考了幾篇文章,搭建了一個在iOS上跑Tensorflow MNIST模型的demo,本文會給出一個...
    Thanatos_defy閱讀 2,462評論 9 10
  • Spring Cloud為開發(fā)人員提供了快速構(gòu)建分布式系統(tǒng)中一些常見模式的工具(例如配置管理,服務發(fā)現(xiàn),斷路器,智...
    卡卡羅2017閱讀 136,695評論 19 139
  • 同樣對待孩子。有人說“我要給他/她全世界最好的?!庇腥苏f“我要讓他/她成為全世界最好的?!?種下什么因,就一定會得...
    sunfeng0912閱讀 267評論 0 3
  • ??藸栠€在悲鳴。 我望著灰暗的天空,身上使不出一絲力氣,直到雨落在眼睛里,眼瞼才掙扎著恢復生機。此刻我該想些什么,...
    不才子閱讀 283評論 0 1
  • 都說程序員上輩子是江湖中飛檐走壁的俠客,在江湖中長劍在手,快意恩仇,瀟瀟灑灑,仗義人生。所以這輩子再也不想過那種漂...
    非著名程序員閱讀 1,383評論 4 8

友情鏈接更多精彩內(nèi)容