一、保存模型
方法一:通過Checkpoint保存
在Keras中有ModelCheckpoint函數(shù),調(diào)用該函數(shù)可以將每個(gè)epoch后的模型進(jìn)行保存。詳見官方文檔。具體的使用方法如下:
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1)
參數(shù):
- filepath: 保存模型的路徑
- monitor: 被監(jiān)測的對象,比如
acc,loss,val_acc,... - verbose: 冗余的,如果想加上進(jìn)度條,就是1,如果不想,就是0
- set_best_only = True: 只保存最好的模型
- save_weights_only = False: 如果 True,那么只有模型的權(quán)重會(huì)被保存
model.save_weights(filepath), 否則的話,整個(gè)模型會(huì)被保存model.save(filepath) - mode = 'auto':
automax和min,如果使監(jiān)測acc,就是max,如果是監(jiān)測loss,就是min,auto就i會(huì)從監(jiān)測值中自己判斷 - period: 每個(gè)檢查點(diǎn)之間的間隔(訓(xùn)練輪數(shù))
方法二:通過save_model()保存
保存模型有兩種形式,一種是將模型整個(gè)都保存下來,包括權(quán)重和結(jié)構(gòu);另一種是只保留權(quán)重。
1. 1保存整個(gè)模型
可以使用model.save(filepath)將Keras的模型保存到HDF5文件中,該文件將包含:模型結(jié)構(gòu)、模型權(quán)重、配置項(xiàng)(優(yōu)化函數(shù)、優(yōu)化器)和優(yōu)化狀態(tài),允許準(zhǔn)確地從上次結(jié)束地地方繼續(xù)訓(xùn)練,詳見官方文檔。
from keras.model import load_model
model.save('my_model.h5')
model = load_model('my_model.h5')
1.2 保存部分模型
1.2.1 分別保存模型的權(quán)重和結(jié)構(gòu)
我們可以使用to_json()和to_yaml()方法將模型結(jié)構(gòu)保存到j(luò)osn文件或者yaml文件中。
'''方法1:保存為json'''
json_string = model.to_json()
open('model_architecture_1.json', 'w').write(json_string) #重命名
#從json中讀出數(shù)據(jù)
from keras.models import model_from_json
model = model_from_json(json_string)
'''方法2:保存為yaml'''
yaml_string = model.to_yaml()
open('model_arthitecture_2.yaml', 'w').write(yaml_string) #重命名
#從json中讀出數(shù)據(jù)
from keras.models import model_from_yaml
model = model_from_yaml(yaml_string)
1.2.2 只保存模型權(quán)重
通過model.save_weights('my_model_weights.h5')將權(quán)重保存在HDF5文件中。如果有可以實(shí)例化模型的代碼。則可以將保存的權(quán)重加載到相同結(jié)構(gòu)的模型中:
model.load_weights('my_model_weights.h5')
二、導(dǎo)入模型
如果我們想導(dǎo)入訓(xùn)練好的最好的模型來進(jìn)行預(yù)測,最好使用方法一,將最好的模型保存下來然后導(dǎo)入進(jìn)行預(yù)測,如果想接著上一次的模型繼續(xù)訓(xùn)練,可以兩種方法都可以。
from keras.models import load_model
model = load_model(filepath)
保存模型可能會(huì)出錯(cuò)的地方:
- filepath不會(huì)自己建立文件夾,例如
checkpointer = ModelCheckpoint(filepath='tmp\model.h5'),如果同目錄下沒有tmp文件夾,程序?qū)?huì)出錯(cuò),模型無法保存。 - 在進(jìn)行訓(xùn)練的時(shí)候,需要把checkpointer加進(jìn)去回調(diào)函數(shù)
callbacks中,并用中括號擴(kuò)起來,即model.fit(x, y, callbacks=[checkpointer])
附:保存模型圖
我們可以通過model.summary()將模型的結(jié)構(gòu)打印出來,另外,我們可以通過plot_model(model, 'model_plot.png')將模型的基本結(jié)構(gòu)框圖保存下來。另外你也可以使用keras.utils.vis_utils模塊將模型的詳細(xì)結(jié)構(gòu)框圖保存下來。
from keras.utils import plot_model
plot_model(model, to_file='model.png')