Keras如何保存訓(xùn)練模型

一、保存模型

方法一:通過Checkpoint保存

在Keras中有ModelCheckpoint函數(shù),調(diào)用該函數(shù)可以將每個(gè)epoch后的模型進(jìn)行保存。詳見官方文檔。具體的使用方法如下:

from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath, 
                            monitor='val_loss', 
                            verbose=0, 
                            save_best_only=False, 
                            save_weights_only=False, 
                            mode='auto', 
                            period=1)

參數(shù):

  • filepath: 保存模型的路徑
  • monitor: 被監(jiān)測的對象,比如acc, loss, val_acc,...
  • verbose: 冗余的,如果想加上進(jìn)度條,就是1,如果不想,就是0
  • set_best_only = True: 只保存最好的模型
  • save_weights_only = False: 如果 True,那么只有模型的權(quán)重會(huì)被保存 model.save_weights(filepath), 否則的話,整個(gè)模型會(huì)被保存 model.save(filepath)
  • mode = 'auto': auto maxmin,如果使監(jiān)測acc,就是max,如果是監(jiān)測loss,就是min,auto就i會(huì)從監(jiān)測值中自己判斷
  • period: 每個(gè)檢查點(diǎn)之間的間隔(訓(xùn)練輪數(shù))

方法二:通過save_model()保存

保存模型有兩種形式,一種是將模型整個(gè)都保存下來,包括權(quán)重和結(jié)構(gòu);另一種是只保留權(quán)重。

1. 1保存整個(gè)模型

可以使用model.save(filepath)將Keras的模型保存到HDF5文件中,該文件將包含:模型結(jié)構(gòu)、模型權(quán)重、配置項(xiàng)(優(yōu)化函數(shù)、優(yōu)化器)和優(yōu)化狀態(tài),允許準(zhǔn)確地從上次結(jié)束地地方繼續(xù)訓(xùn)練,詳見官方文檔。

from keras.model import load_model
model.save('my_model.h5')

model = load_model('my_model.h5')

1.2 保存部分模型

1.2.1 分別保存模型的權(quán)重和結(jié)構(gòu)

我們可以使用to_json()to_yaml()方法將模型結(jié)構(gòu)保存到j(luò)osn文件或者yaml文件中。

'''方法1:保存為json'''
json_string = model.to_json()
open('model_architecture_1.json', 'w').write(json_string)   #重命名

#從json中讀出數(shù)據(jù)
from keras.models import model_from_json
model = model_from_json(json_string)
'''方法2:保存為yaml'''
yaml_string = model.to_yaml()
open('model_arthitecture_2.yaml', 'w').write(yaml_string)  #重命名

#從json中讀出數(shù)據(jù)
from keras.models import model_from_yaml
model = model_from_yaml(yaml_string)

1.2.2 只保存模型權(quán)重

通過model.save_weights('my_model_weights.h5')將權(quán)重保存在HDF5文件中。如果有可以實(shí)例化模型的代碼。則可以將保存的權(quán)重加載到相同結(jié)構(gòu)的模型中:

model.load_weights('my_model_weights.h5')

二、導(dǎo)入模型

如果我們想導(dǎo)入訓(xùn)練好的最好的模型來進(jìn)行預(yù)測,最好使用方法一,將最好的模型保存下來然后導(dǎo)入進(jìn)行預(yù)測,如果想接著上一次的模型繼續(xù)訓(xùn)練,可以兩種方法都可以。

from keras.models import load_model
model = load_model(filepath)

保存模型可能會(huì)出錯(cuò)的地方:

  1. filepath不會(huì)自己建立文件夾,例如checkpointer = ModelCheckpoint(filepath='tmp\model.h5'),如果同目錄下沒有tmp文件夾,程序?qū)?huì)出錯(cuò),模型無法保存。
  2. 在進(jìn)行訓(xùn)練的時(shí)候,需要把checkpointer加進(jìn)去回調(diào)函數(shù)callbacks中,并用中括號擴(kuò)起來,即model.fit(x, y, callbacks=[checkpointer])

附:保存模型圖

我們可以通過model.summary()將模型的結(jié)構(gòu)打印出來,另外,我們可以通過plot_model(model, 'model_plot.png')將模型的基本結(jié)構(gòu)框圖保存下來。另外你也可以使用keras.utils.vis_utils模塊將模型的詳細(xì)結(jié)構(gòu)框圖保存下來。

from keras.utils import plot_model
plot_model(model, to_file='model.png')

參考文檔:
官方文檔
Keras保存最好的模型
Keras框架下的保存模型和加載模型

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容