在使用TensorFlow進(jìn)行建模、訓(xùn)練和預(yù)測(cè)時(shí),可以使用estimator這樣的高階函數(shù)方便使用?;镜奶茁肥牵?/p>
訓(xùn)練 fit
from tensorflow.contrib import learn
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
validation_monitor = learn.monitors.ValidationMonitor(x['val'], y['val'],
every_n_steps=config['print_steps'],
early_stopping_rounds=1000)
classifier.fit(x['train'], y['train'],monitors=[validation_monitor], batch_size=config['batch_size'], steps=len(x['train'])*config['epochs'])
- 初始化classifer
- 配置monitor
- 使用fit函數(shù)進(jìn)行訓(xùn)練
預(yù)測(cè) predict
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
pre = classifier.predict(x['test'])
- 初始化classifier,其中model_dir的配置需要指定為和訓(xùn)練相同的目錄
- 使用predict
根據(jù)官方文檔的步驟,這樣predict即可使用訓(xùn)練好的模型和參數(shù)。
但是實(shí)際結(jié)果是這樣predict會(huì)報(bào)錯(cuò),具體錯(cuò)誤如下:
ValueError: Tried to convert 'values' to a tensor and failed. Error: None values not supported.
修正后的predict
這其實(shí)是learn.predict的一個(gè)bug,目前社區(qū)還沒(méi)有fix,不過(guò)在issue-3208 給出了work around的方法,在predict之前使用eval,類似于聲明classifier并沒(méi)有實(shí)例化,在eval的時(shí)候,會(huì)將其實(shí)例化,這樣在之后使用predict的時(shí)候就不會(huì)再報(bào)Error: None values not supported.的錯(cuò)誤了。當(dāng)然這樣做的缺點(diǎn)是因?yàn)閑val會(huì)使用TF對(duì)模型進(jìn)行運(yùn)算至少一次,這樣會(huì)造成一定的性能損耗。
work around的方法
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
classifier.evaluate(x=x['val'],y=y['val'],steps=1)
pre = classifier.predict(x['test'])

work around.png