tensorflow提供兩種模型格式
- checkpoint:依賴于創(chuàng)建模型的代碼
- SavedModel:與模型代碼無關(guān)
這里盡介紹checkpoint
1. 保存經(jīng)過部分訓(xùn)練的模型
Estimator自動(dòng)將如下內(nèi)容寫入磁盤
- checkpoints: 訓(xùn)練期間所創(chuàng)建的模型版本
- event files: 包含有TensorBoard用于創(chuàng)建可視化圖標(biāo)的全部信息
要指定模型的頂級(jí)存儲(chǔ)目錄,可以使用Estimator構(gòu)造函數(shù)的可選參數(shù)model_dir,設(shè)置代碼如下所示
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir="./models_dir")
當(dāng)調(diào)用Estimator的train方法時(shí),Estimator會(huì)將checkpoint和其他文件保存到model_dir目錄中,保存之后,這個(gè)目錄中的文件如下所示:
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
這個(gè)目錄存儲(chǔ)的是Estimator在第一步訓(xùn)練開始和第200不訓(xùn)練結(jié)束時(shí)創(chuàng)建的checkpoints
2. checkpoint頻率
默認(rèn)情況下,Estimator按照如下時(shí)間將checkpoint保存到model_dir中
- 每600秒保存一次
- 在
train方法開始以及完成時(shí)都要保存checkpoint - 在目錄中最多保留5個(gè)最近的checkpoints
可以通過如下步驟來更改默認(rèn)設(shè)置:
- 創(chuàng)建
RunConfig對(duì)象來自定義設(shè)置 - 在實(shí)例化Estimator時(shí),將該
RunConfig對(duì)象傳遞個(gè)Estimatro的config參數(shù)
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60,
keep_checkpoint_max = 10,
)
3. 從checkpoint中恢復(fù)模型
在第一次調(diào)用Estimator的train方法時(shí),Tensorflow會(huì)將checkpoint保存到model_dir中,隨后每次調(diào)用Estimator的train、eval或者predict方法時(shí),都會(huì)發(fā)生下列情況:
- Esitmator運(yùn)行
model_fun()構(gòu)建模型圖 - Estimator根據(jù)最近寫入的checkpoint中存儲(chǔ)的數(shù)據(jù)來初始化新模型的權(quán)重
4. 避免不當(dāng)恢復(fù)
通過checkpoint恢復(fù)模型的狀態(tài)必須保證模型和checkpoint保存的兼容才可以。例如我們訓(xùn)練了一個(gè)DNNClassifier Estimator,它包含有2個(gè)隱藏層且每層都有10個(gè)節(jié)點(diǎn),經(jīng)過訓(xùn)練兵保存了checkpoint到model_dir中。后續(xù)在訓(xùn)練的時(shí)候,假如將代碼中的隱藏層修改為了每層20個(gè)節(jié)點(diǎn),這樣用這樣的Estimator調(diào)用train時(shí)就回報(bào)錯(cuò),因?yàn)閏heckpoint保存的模型結(jié)構(gòu)與代碼中的模型是不兼容的。這一點(diǎn)切記。