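Readers who have followed "Training a CNN Classification Model with TensorFlow-slim (Continued)" and the related articles in that series have probably already seen how convenient tf.contrib.slim is for training convolutional neural networks, and especially how intuitive it makes building a model. It does have one significant drawback, though: while a model is training, slim offers no interface for quickly and conveniently checking the model's performance on a validation set. Suppose we have a training set and a test set. We want to train on the training set while continually watching the model's performance on the test set, so that we can quickly tell whether it is overfitting. However, slim's training function slim.learning.train provides no interface for validation data, so unless you write extra code yourself, you cannot monitor test-set performance during training. For anyone who wants to tune hyperparameters quickly, this is a serious limitation. Fortunately, TensorFlow's newer high-level API, tf.estimator.Estimator, fills this gap. In short, the Estimator interface was built to streamline training: it can train and evaluate a model in the same run, which makes the training process simpler and easier to control.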
This article combines slim's ease of use for building models with Estimator's convenience for training them, taking the strengths of each to further speed up the implementation of deep learning projects.
The data is the cats-vs-dogs classification dataset (the Kaggle competition dataset, shared via Baidu Netdisk), where cats have class label 0 and dogs have class label 1. All of the code is available on GitHub: slim_cnn_estimator.
I. Model Definition
The model is still written with tf.contrib.slim (saved as model.py):
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from tensorflow.contrib.slim import nets
import preprocessing
slim = tf.contrib.slim
class Model(object):
    """Classification model definition."""

    def __init__(self, is_training,
                 num_classes=2,
                 fixed_resize_side=256,
                 default_image_size=224):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
            fixed_resize_side: Passed to preprocessing.preprocess_images as
                resize_side_min.
            default_image_size: Output height and width of preprocessing.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size

    @property
    def num_classes(self):
        return self._num_classes

    def preprocess(self, inputs):
        """Preprocesses a batch of images.

        Outputs of this function can be passed to the predict function.

        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor holding the preprocessed
                images.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size,
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=False, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        # With num_classes=None, net is the globally pooled feature map of
        # shape [batch_size, 1, 1, 2048]; squeeze it to [batch_size, 2048].
        net = tf.squeeze(net, axis=[1, 2])
        # New fully connected layer replacing ResNet-50's original logits.
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None,
                                      scope='Predict/logits')
        return {'logits': logits}

    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.

        Returns:
            A dictionary containing the postprocessed results.
        """
        postprocessed_dict = {}
        for logits_name, logits in prediction_dict.items():
            logits = tf.nn.softmax(logits)
            classes = tf.argmax(logits, axis=1)
            classes_name = logits_name.replace('logits', 'classes')
            # Note: the softmax probabilities are stored under the original
            # 'logits' key, next to the predicted classes.
            postprocessed_dict[logits_name] = logits
            postprocessed_dict[classes_name] = classes
        return postprocessed_dict

    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A tensor of groundtruth class labels, with
                one entry for each image in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
            representing loss values.
        """
        logits = prediction_dict.get('logits')
        slim.losses.sparse_softmax_cross_entropy(logits, groundtruth_lists)
        # The total loss also includes the weight-regularization losses.
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict

    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.

        Args:
            postprocessed_dict: A dictionary containing the postprocessed
                results.
            groundtruth_lists: A tensor of groundtruth class labels, with
                one entry for each image in the batch.

        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy
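The network structure is very simple (see the predict function): it only replaces the last fully connected layer of ResNet-50. For how to write neural network models with slim, see the article "Training a CNN Classification Model with TensorFlow-slim". The other functions do what their names suggest and are equally simple.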
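To make the key naming in postprocess concrete, here is a framework-free sketch in plain Python, using nested lists instead of tensors. The helper names softmax and postprocess are illustrative and not part of the repository. It mirrors the behavior described above: every key containing 'logits' is overwritten with softmax probabilities, and a matching 'classes' key holds the argmax, which for this dataset is 0 (cat) or 1 (dog).

```python
import math


def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(prediction_dict):
    """Plain-Python analogue of Model.postprocess (illustration only)."""
    postprocessed = {}
    for logits_name, logits in prediction_dict.items():
        probs = [softmax(row) for row in logits]
        # argmax over the class axis (axis=1 in the TensorFlow version)
        classes = [max(range(len(p)), key=p.__getitem__) for p in probs]
        postprocessed[logits_name] = probs  # probabilities keep the old key
        postprocessed[logits_name.replace('logits', 'classes')] = classes
    return postprocessed


out = postprocess({'logits': [[2.0, 1.0], [0.0, 3.0]]})
print(out['classes'])  # → [0, 1]
```

Because the probabilities replace the raw logits under the same key, code that later reads postprocessed_dict['logits'] receives probabilities, not logits.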
II. Model Training
Training a model with tf.estimator requires two key functions: an input function (input_fn) that feeds in the data, and a model function (model_fn) that builds the model. Each is explained below. (As in the earlier articles, the data is still stored in TFRecord format.)
First, here is an outline of the training process, in the order in which things are called (full official documentation: tf.estimator):
- Use tf.estimator.train_and_evaluate to start training and evaluation. Its full signature is:
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Here estimator is a tf.estimator.Estimator object that specifies the model function and related parameters; train_spec is a tf.estimator.TrainSpec object that specifies the training input function and other parameters; eval_spec is a tf.estimator.EvalSpec object that specifies the evaluation input function and other parameters.
- Use tf.estimator.Estimator to create the Estimator instance estimator. The full signature of the Estimator class is:
tf.estimator.Estimator(model_fn, model_dir=None, config=None,
                       params=None, warm_start_from=None)
Here model_fn is the model function; model_dir is the directory where the model is saved during training; config is a tf.estimator.RunConfig configuration object; params is a dictionary of hyperparameters passed to model_fn; warm_start_from is either the path to a pretrained checkpoint or a tf.estimator.WarmStartSettings object that fully configures warm-starting.
- Use tf.estimator.TrainSpec to specify the training input function and related parameters. Its full signature is:
tf.estimator.TrainSpec(input_fn, max_steps=None, hooks=None)
Here input_fn supplies the input data for training; max_steps is the total number of training steps; hooks is an iterable of tf.train.SessionRunHook objects that are run during training.
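- Use tf.estimator.EvalSpec to specify the evaluation input function and related parameters. Its full signature is:
tf.estimator.EvalSpec(
    input_fn,
    steps=100,
    name=None,
    hooks=None,
    exporters=None,
    start_delay_secs=120,
    throttle_secs=600)
Here input_fn supplies the input data for evaluation; steps is the number of evaluation steps (it can usually be set to None, in which case evaluation runs until input_fn runs out of data); hooks is an iterable of tf.train.SessionRunHook objects that are run during evaluation; exporters is an Exporter or an iterable of Exporters invoked after each evaluation; start_delay_secs is how many seconds to wait before starting to evaluate; throttle_secs is how many seconds must pass before a new round of evaluation starts (if no new checkpoint has been saved by then, no evaluation runs, so this is the minimum waiting time between evaluations).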
- Define the model function model_fn, which returns an instance of tf.estimator.EstimatorSpec. Its general form is (the function name is arbitrary):
import tensorflow as tf


def create_model_fn(features, labels, mode, params=None):
    params = params or {}
    loss, train_op = None, None  # plus any other mode-specific outputs
    prediction_dict = ...  # build the model and compute its predictions
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = ...  # compute the loss
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...  # create the optimization op
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=prediction_dict,
        loss=loss,
        train_op=train_op)  # plus eval_metric_ops, export_outputs, etc.
Here features and labels can each be a tensor or a dictionary of tensors; mode is the current mode, one of TRAIN, EVAL, and PREDICT; params is an optional dictionary of additional hyperparameters. model_fn must define the model's predictions, loss, optimizer, and so on, and it returns a tf.estimator.EstimatorSpec object.
- The full signature of the class tf.estimator.EstimatorSpec is:
tf.estimator.EstimatorSpec(
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None)
Here mode indicates whether the model is being trained, evaluated, or used for prediction; predictions is a prediction tensor or a dictionary of tensors; loss is the loss tensor; train_op is the optimization op; eval_metric_ops is a dictionary of evaluation metrics whose values must take one of the following two forms:
  - an instance of the Metric class;
  - the result of calling a metric function, i.e. a (metric_tensor, update_op) pair;
export_outputs is used only when exporting the model and describes the output signatures written to a SavedModel; scaffold is a tf.train.Scaffold object that can be used to set up initialization, saving, and so on during training.
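- Define the input function input_fn, which returns one of the following two formats:
  - a tf.data.Dataset object: its elements must be (features, labels) tuples that satisfy the same constraints as the next item;
  - a tuple (features, labels): features and labels must each be a tensor or a dictionary of tensors.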
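The (metric_tensor, update_op) form of eval_metric_ops is what lets a metric cover the whole validation set rather than a single batch: update_op accumulates statistics batch by batch, and metric_tensor reads the result over everything seen so far. The class below is a plain-Python analogue of the streaming accuracy that tf.metrics.accuracy provides (the training code later uses it for eval_metric_ops); it is an illustration of the idea, not TensorFlow's implementation, and the class name is hypothetical.

```python
class StreamingAccuracy:
    """Plain-Python analogue of a (metric_tensor, update_op) pair.

    update() plays the role of update_op (accumulate one batch);
    value() plays the role of metric_tensor (read the running result).
    """

    def __init__(self):
        self.correct = 0  # correct predictions seen so far
        self.seen = 0     # examples seen so far

    def update(self, labels, predictions):
        for label, pred in zip(labels, predictions):
            self.seen += 1
            self.correct += int(label == pred)
        return self.value()

    def value(self):
        return self.correct / self.seen if self.seen else 0.0


acc = StreamingAccuracy()
acc.update([0, 1, 1, 0], [0, 1, 0, 0])  # batch 1: 3 of 4 correct
acc.update([1, 1], [1, 1])              # batch 2: 2 of 2 correct
print(acc.value())  # 5 correct out of 6 examples
```

Averaging the per-batch accuracies instead (0.75 and 1.0) would give 0.875, which is why the metric accumulates counts rather than per-batch means.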
With this background, let's look at the training code that uses tf.estimator (saved as train.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 19:27:44 2018
@author: shirhe-lyh
Train a CNN model to classify cats and dogs.
Example Usage:
---------------
python3 train.py \
    --train_record_path path/to/train.record \
    --val_record_path path/to/val.record \
    --model_dir path/to/log/directory
"""
import functools
import logging
import os
import tensorflow as tf
import exporter
import model
slim = tf.contrib.slim
flags = tf.app.flags
flags.DEFINE_string('gpu_indices', '0', 'The indices of the GPUs to use.')
flags.DEFINE_string('train_record_path',
                    './datasets/train.record',
                    'Path to training tfrecord file.')
flags.DEFINE_string('val_record_path',
                    './datasets/val.record',
                    'Path to validation tfrecord file.')
flags.DEFINE_string('checkpoint_path',
                    None,
                    'Path to a pretrained model.')
flags.DEFINE_string('model_dir', './training', 'Path to log directory.')
flags.DEFINE_float('keep_checkpoint_every_n_hours',
                   0.2,
                   'Keep one checkpoint for every n hours of training.')
flags.DEFINE_string('learning_rate_decay_type',
                    'exponential',
                    'Specifies how the learning rate is decayed. One of '
                    '"fixed", "exponential", or "polynomial"')
flags.DEFINE_float('learning_rate',
                   0.0001,
                   'Initial learning rate.')
flags.DEFINE_float('end_learning_rate',
                   0.000001,
                   'The minimal end learning rate used by a polynomial decay '
                   'learning rate.')
flags.DEFINE_integer('decay_steps',
                     1000,
                     'Decay period of the learning rate, in global steps.')
flags.DEFINE_float('learning_rate_decay_factor',
                   0.5,
                   'Learning rate decay factor.')
flags.DEFINE_integer('num_classes', 2, 'Number of classes.')
flags.DEFINE_integer('batch_size', 64, 'Batch size.')
flags.DEFINE_integer('num_steps', 5000, 'Number of steps.')
flags.DEFINE_integer('input_size', 224,
                     'Input image size; images are resized to '
                     'input_size + 32 before preprocessing.')
FLAGS = flags.FLAGS
def get_decoder():
    """Returns a TFExampleDecoder."""
    keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label':
            tf.FixedLenFeature([1], tf.int64,
                               default_value=tf.zeros([1], dtype=tf.int64))}

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format',
                                              channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    return decoder
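def transform_data(image):
    # Resize every image to the same (input_size + 32) square so that
    # images of different sizes can be batched together.
    size = FLAGS.input_size + 32
    image = tf.squeeze(tf.image.resize_bilinear([image], size=[size, size]))
    image = tf.to_float(image)
    return image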
def read_dataset(file_read_fun, input_files, num_readers=1, shuffle=False,
                 num_epochs=0, read_block_length=32, shuffle_buffer_size=2048):
    """Reads a dataset, and handles repetition and shuffling.

    This function and the following are modified from:
    https://github.com/tensorflow/models/blob/master/research/
    object_detection/builders/dataset_builder.py

    Args:
        file_read_fun: Function to use in tf.contrib.data.parallel_interleave,
            to read every individual file into a tf.data.Dataset.
        input_files: A list of file paths to read.
        num_readers: Number of files to read in parallel.
        shuffle: Whether to shuffle the file names and the records.
        num_epochs: Number of passes over the data; 0 means repeat forever.
        read_block_length: Number of consecutive records read from one file
            before moving on to the next file.
        shuffle_buffer_size: Buffer size used to shuffle the records.

    Returns:
        A tf.data.Dataset of (undecoded) tf-records.
    """
    # Shard, shuffle, and read files
    filenames = tf.gfile.Glob(input_files)
    if num_readers > len(filenames):
        num_readers = len(filenames)
        tf.logging.warning('num_readers has been reduced to %d to match input '
                           'file shards.' % num_readers)
    filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if shuffle:
        filename_dataset = filename_dataset.shuffle(100)
    elif num_readers > 1:
        tf.logging.warning('`shuffle` is false, but the input data stream is '
                           'still slightly shuffled since `num_readers` > 1.')
    filename_dataset = filename_dataset.repeat(num_epochs or None)
    records_dataset = filename_dataset.apply(
        tf.contrib.data.parallel_interleave(
            file_read_fun,
            cycle_length=num_readers,
            block_length=read_block_length,
            sloppy=shuffle))
    if shuffle:
        records_dataset = records_dataset.shuffle(shuffle_buffer_size)
    return records_dataset
def create_input_fn(record_paths, batch_size=64,
                    num_epochs=0, num_parallel_batches=8,
                    num_prefetch_batches=2):
    """Create a train or eval `input` function for `Estimator`.

    Args:
        record_paths: A list containing the paths of tfrecords.
        batch_size: Number of examples per batch.
        num_epochs: Number of passes over the data; 0 means repeat forever.
        num_parallel_batches: Together with batch_size, determines how many
            records are decoded in parallel.
        num_prefetch_batches: Number of batches to prefetch.

    Returns:
        `input_fn` for `Estimator` in TRAIN/EVAL mode.
    """
    def _input_fn():
        decoder = get_decoder()

        def decode(value):
            keys = decoder.list_items()
            tensors = decoder.decode(value)
            tensor_dict = dict(zip(keys, tensors))
            image = tensor_dict.get('image')
            image = transform_data(image)
            features_dict = {'image': image}
            return features_dict, tensor_dict.get('label')

        dataset = read_dataset(
            functools.partial(tf.data.TFRecordDataset,
                              buffer_size=8 * 1000 * 1000),
            input_files=record_paths,
            num_epochs=num_epochs)
        if batch_size:
            num_parallel_calls = batch_size * num_parallel_batches
        else:
            num_parallel_calls = num_parallel_batches
        dataset = dataset.map(decode, num_parallel_calls=num_parallel_calls)
        if batch_size:
            dataset = dataset.apply(
                tf.contrib.data.batch_and_drop_remainder(batch_size))
        dataset = dataset.prefetch(num_prefetch_batches)
        return dataset

    return _input_fn
def create_predict_input_fn():
    """Creates a predict `input` function for `Estimator`.

    Modified from:
    https://github.com/tensorflow/models/blob/master/research/
    object_detection/inputs.py

    Returns:
        `input_fn` for `Estimator` in PREDICT mode.
    """
    def _predict_input_fn():
        """Decodes serialized tf.Examples and returns `ServingInputReceiver`.

        Returns:
            `ServingInputReceiver`.
        """
        example = tf.placeholder(dtype=tf.string, shape=[], name='tf_example')

        decoder = get_decoder()
        keys = decoder.list_items()
        tensors = decoder.decode(example, items=keys)
        tensor_dict = dict(zip(keys, tensors))
        image = tensor_dict.get('image')
        image = transform_data(image)
        images = tf.expand_dims(image, axis=0)
        return tf.estimator.export.ServingInputReceiver(
            features={'image': images},
            receiver_tensors={'serialized_example': example})

    return _predict_input_fn
def create_model_fn(features, labels, mode, params=None):
    """Constructs the classification model.

    Modified from:
    https://github.com/tensorflow/models/blob/master/research/
    object_detection/model_lib.py.

    Args:
        features: A dict whose 'image' entry is a 4-D float32 tensor with
            shape [batch_size, height, width, channels] representing a batch
            of images.
        labels: A 1-D int64 tensor with shape [batch_size] representing
            the labels of each image (None in PREDICT mode).
        mode: Mode key for tf.estimator.ModeKeys.
        params: Parameter dictionary passed from the estimator.

    Returns:
        An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    loss, acc, train_op, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    cls_model = model.Model(is_training=is_training,
                            num_classes=FLAGS.num_classes)
    preprocessed_inputs = cls_model.preprocess(features.get('image'))
    prediction_dict = cls_model.predict(preprocessed_inputs)
    postprocessed_dict = cls_model.postprocess(prediction_dict)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Warm-start from a pretrained checkpoint, if one is given.
        if FLAGS.checkpoint_path:
            init_variables_from_checkpoint()

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss_dict = cls_model.loss(prediction_dict, labels)
        loss = loss_dict['loss']
        classes = postprocessed_dict['classes']
        acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', acc)

    scaffold = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        learning_rate = configure_learning_rate(FLAGS.decay_steps,
                                                global_step)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)
        train_op = slim.learning.create_train_op(loss, optimizer,
                                                 summarize_gradients=True)

        keep_checkpoint_every_n_hours = FLAGS.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)

    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        # Streaming accuracy over the whole validation set.
        accuracy = tf.metrics.accuracy(labels=labels, predictions=classes)
        eval_metric_ops = {'Accuracy': accuracy}

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_output = exporter._add_output_tensor_nodes(postprocessed_dict)
        export_outputs = {
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(export_output)}

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=prediction_dict,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops,
                                      export_outputs=export_outputs,
                                      scaffold=scaffold)
def configure_learning_rate(decay_steps, global_step):
    """Configures the learning rate.

    Modified from:
    https://github.com/tensorflow/models/blob/master/research/slim/
    train_image_classifier.py

    Args:
        decay_steps: Decay period of the learning rate, in global steps.
        global_step: The global_step tensor.

    Returns:
        A `Tensor` representing the learning rate.
    """
    if FLAGS.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(FLAGS.learning_rate,
                                          global_step,
                                          decay_steps,
                                          FLAGS.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
        return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(FLAGS.learning_rate,
                                         global_step,
                                         decay_steps,
                                         FLAGS.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         FLAGS.learning_rate_decay_type)
def init_variables_from_checkpoint(checkpoint_exclude_scopes=None):
    """Variable initialization from a given checkpoint path.

    Modified from:
    https://github.com/tensorflow/models/blob/master/research/
    object_detection/model_lib.py

    Note that tf.train.init_from_checkpoint only overrides the variables'
    initializers, so it takes effect only when the model is initialized from
    scratch (i.e. before the very first global step).

    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.
    """
    exclude_patterns = None
    if checkpoint_exclude_scopes:
        exclude_patterns = [scope.strip() for scope in
                            checkpoint_exclude_scopes.split(',')]
    variables_to_restore = tf.global_variables()
    variables_to_restore.append(slim.get_or_create_global_step())
    variables_to_init = tf.contrib.framework.filter_variables(
        variables_to_restore, exclude_patterns=exclude_patterns)
    variables_to_init_dict = {var.op.name: var for var in variables_to_init}

    # Keep only the variables whose names and shapes match the checkpoint.
    available_var_map = get_variables_available_in_checkpoint(
        variables_to_init_dict, FLAGS.checkpoint_path,
        include_global_step=False)
    tf.train.init_from_checkpoint(FLAGS.checkpoint_path, available_var_map)
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
    """Returns the subset of variables in the checkpoint.

    Inspects given checkpoint and returns the subset of variables that are
    available in it.

    Args:
        variables: A dictionary of variables to find in checkpoint.
        checkpoint_path: Path to the checkpoint to restore variables from.
        include_global_step: Whether to include `global_step` variable, if it
            exists. Default True.

    Returns:
        A dictionary of variables.

    Raises:
        ValueError: If `variables` is not a dict.
    """
    if not isinstance(variables, dict):
        raise ValueError('`variables` is expected to be a dict.')

    # Available variables
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
        ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    vars_in_ckpt = {}
    for variable_name, variable in sorted(variables.items()):
        if variable_name in ckpt_vars_to_shape_map:
            if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
                vars_in_ckpt[variable_name] = variable
            else:
                logging.warning('Variable [%s] is available in checkpoint, but '
                                'has an incompatible shape with model '
                                'variable. Checkpoint shape: [%s], model '
                                'variable shape: [%s]. This variable will not '
                                'be initialized from the checkpoint.',
                                variable_name,
                                ckpt_vars_to_shape_map[variable_name],
                                variable.shape.as_list())
        else:
            logging.warning('Variable [%s] is not available in checkpoint',
                            variable_name)
    return vars_in_ckpt
def main(_):
    # Specify which GPUs to use
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_indices

    estimator = tf.estimator.Estimator(model_fn=create_model_fn,
                                       model_dir=FLAGS.model_dir)

    train_input_fn = create_input_fn([FLAGS.train_record_path],
                                     batch_size=FLAGS.batch_size)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=FLAGS.num_steps)

    # Each evaluation makes one full pass (num_epochs=1) over the
    # validation set.
    eval_input_fn = create_input_fn([FLAGS.val_record_path],
                                    batch_size=FLAGS.batch_size,
                                    num_epochs=1)
    # Export a SavedModel for serving once, at the end of training.
    predict_input_fn = create_predict_input_fn()
    eval_exporter = tf.estimator.FinalExporter(
        name='servo', serving_input_receiver_fn=predict_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None,
                                      exporters=eval_exporter)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
if __name__ == '__main__':
    tf.app.run()
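To see what configure_learning_rate does with the default flags, the sketch below re-implements the two decay formulas in plain Python. The formulas follow the documented behavior of tf.train.exponential_decay (with staircase=True) and tf.train.polynomial_decay (with cycle=False); the function names here are illustrative, not TensorFlow's.

```python
import math


def exponential_decay(lr, step, decay_steps, decay_rate, staircase=True):
    """lr * decay_rate ** (step / decay_steps); floored when staircase=True."""
    p = step / decay_steps
    if staircase:
        p = math.floor(p)
    return lr * decay_rate ** p


def polynomial_decay(lr, step, decay_steps, end_lr, power=1.0):
    """Polynomial decay from lr to end_lr over decay_steps (cycle=False)."""
    step = min(step, decay_steps)  # stays at end_lr after decay_steps
    return (lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


# Defaults from train.py: learning_rate=1e-4, decay_steps=1000,
# learning_rate_decay_factor=0.5, end_learning_rate=1e-6, num_steps=5000.
for step in (0, 999, 1000, 2500, 4999):
    print(step,
          exponential_decay(1e-4, step, 1000, 0.5),
          polynomial_decay(1e-4, step, 1000, 1e-6))
```

With these defaults, the exponential schedule halves the learning rate every 1000 steps. The polynomial schedule, however, reaches end_learning_rate at step 1000 and stays there for the remaining 4000 steps. If you choose "polynomial", you may want to set decay_steps close to num_steps.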
[Notes]
1. If your training data is split across multiple tfrecord files, modify these lines in the main function:
train_input_fn = create_input_fn([FLAGS.train_record_path],
                                 batch_size=FLAGS.batch_size)
and
eval_input_fn = create_input_fn([FLAGS.val_record_path],
                                batch_size=FLAGS.batch_size,
                                num_epochs=1)
replacing [FLAGS.train_record_path] and [FLAGS.val_record_path] with your lists of tfrecord file paths.
2. The path of the pretrained model is given by FLAGS.checkpoint_path; see the function init_variables_from_checkpoint for how it is loaded. If you do not use a pretrained model, simply leave FLAGS.checkpoint_path set to None.
3. The training input function can also be defined another way (this approach comes from an earlier article):
flags.DEFINE_integer('num_train_samples', 50000, 'Number of samples.')
def get_record_dataset(record_path, reader=None, num_samples=50000,
                       num_classes=10):
    """Wraps a tfrecord file in a slim Dataset.

    Args:
        record_path: Path to the tfrecord file(s).
        reader: The TensorFlow reader class; defaults to tf.TFRecordReader.
        num_samples: Number of samples in the tfrecord file(s).
        num_classes: Number of classes.

    Returns:
        A slim.dataset.Dataset describing the tfrecord file(s).
    """
    if not reader:
        reader = tf.TFRecordReader

    decoder = get_decoder()
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)


def create_train_input_fn(record_path, batch_size=64,
                          num_samples=50000, num_classes=2):
    """Creates a train `input` function for `Estimator`.

    Returns:
        `input_fn` for `Estimator` in TRAIN mode.
    """
    def _train_input_fn():
        dataset = get_record_dataset(record_path,
                                     num_samples=num_samples,
                                     num_classes=num_classes)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
        image, label = data_provider.get(['image', 'label'])
        image = transform_data(image)
        inputs, labels = tf.train.batch([image, label],
                                        batch_size=batch_size,
                                        allow_smaller_final_batch=True)
        return {'image': inputs}, labels

    return _train_input_fn
To use it, create the training input function with
train_input_fn = create_train_input_fn(FLAGS.train_record_path,
                                       batch_size=FLAGS.batch_size,
                                       num_samples=FLAGS.num_train_samples,
                                       num_classes=FLAGS.num_classes)
The drawback of this approach is that you must know the number of samples in the tfrecord file and set it through
flags.DEFINE_integer('num_train_samples', 50000, 'Number of samples.')
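One way to obtain that sample count is to walk the TFRecord framing directly. Each record is stored as a little-endian uint64 length, a uint32 CRC of the length, the record bytes, and a uint32 CRC of the data. The sketch below uses only the Python standard library and skips, rather than verifies, the CRCs; the function name count_tfrecords is hypothetical. Inside TensorFlow 1.x, sum(1 for _ in tf.python_io.tf_record_iterator(path)) gives the same count.

```python
import struct


def count_tfrecords(path):
    """Count the records in a TFRecord file by walking its length headers.

    Record layout: uint64 length (little-endian), uint32 length CRC,
    `length` bytes of data, uint32 data CRC. CRCs are skipped, not checked.
    """
    count = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError('Truncated record header in %s' % path)
            (length,) = struct.unpack('<Q', header)
            f.seek(4 + length + 4, 1)  # skip length CRC, data, data CRC
            count += 1
    return count
```

For example, count_tfrecords('./datasets/train.record') could supply the value for num_train_samples instead of a hard-coded 50000.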
III. Training Process
1. Generate the TFRecords
Run:
python3 generate_tfrecord.py --images_dir path/to/images
to generate the tfrecords (replace path/to/images with the path to the folder containing your cats-vs-dogs training images). Other arguments, such as train_annotation_path, can be set as needed. If you keep all the defaults, a folder named datasets is created in the current directory, containing the training and validation tfrecords train.record and val.record.
2. Start training
Run
python3 train.py --train_record_path path/to/train.record \
    --val_record_path path/to/val.record \
    --checkpoint_path path/to/resnet_v1_50.ckpt \
    --model_dir path/to/directory/to/saved/trained/models
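to start training. train_record_path and val_record_path point to the training and validation tfrecord files; if you ran generate_tfrecord.py with its default paths, you don't need to set them here and can rely on the defaults. checkpoint_path points to the pretrained model file resnet_v1_50.ckpt; if you don't use a pretrained model, simply omit this argument so that it keeps its default value None. model_dir is the directory where models are saved; you can use the default, a folder named training that is created automatically in the current directory.
If you have multiple training and validation tfrecord files, see Note 1 in II. Model Training.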
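Note 1 only asks for a list of paths. If your shards share a naming pattern, a small helper built on Python's glob module can assemble that list so you don't have to type it by hand. The helper name list_record_files and the '<prefix>*.record' naming scheme are assumptions for illustration, not part of the repository.

```python
import glob
import os


def list_record_files(directory, prefix):
    """Return sorted paths of the '<prefix>*.record' files in `directory`."""
    pattern = os.path.join(directory, prefix + '*.record')
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise ValueError('No tfrecord files match %s' % pattern)
    return paths


# Usage in main(), assuming shards such as ./datasets/train-00000.record:
#   train_input_fn = create_input_fn(list_record_files('./datasets', 'train'),
#                                    batch_size=FLAGS.batch_size)
```

Raising an error when nothing matches surfaces a wrong directory or prefix immediately, instead of leaving the failure to the input pipeline.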
3. Training curves
Run
tensorboard --logdir path/to/model_dir
to monitor the loss and accuracy curves during training. Set logdir to the directory where your models are saved, i.e. the model_dir from step 2.
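Each point on the validation curve comes from one evaluation run. How often those runs happen is governed by EvalSpec's start_delay_secs and throttle_secs, which main() leaves at their defaults of 120 and 600 seconds. Because main() passes no RunConfig, checkpoints are written on RunConfig's default schedule of every 600 seconds. The function below is a simplified model of the documented rules, not TensorFlow's actual scheduler: it considers an evaluation whenever a new checkpoint appears and ignores how long each evaluation takes. The 300-second checkpoint interval in the example is hypothetical.

```python
def evaluation_times(checkpoint_times, start_delay_secs=120, throttle_secs=600):
    """Simplified model of when evaluations run during training.

    An evaluation is considered whenever a new checkpoint appears. It runs
    only if start_delay_secs have elapsed and at least throttle_secs have
    passed since the previous evaluation started; otherwise that checkpoint
    is skipped.
    """
    evals = []
    for t in sorted(checkpoint_times):
        if t < start_delay_secs:
            continue  # still within the initial delay
        if evals and t - evals[-1] < throttle_secs:
            continue  # throttled: too soon after the last evaluation
        evals.append(t)
    return evals


# Hypothetical: a checkpoint every 300 s during one hour of training.
print(evaluation_times(range(300, 3601, 300)))
# → [300, 900, 1500, 2100, 2700, 3300]
```

In this model, saving checkpoints more often than throttle_secs does not produce more evaluation points; lowering throttle_secs does.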
For example, my training curves are shown below (the blue line is the evaluation result on the validation set, the yellow line is the result on the training set):
[Figure: loss and accuracy curves in TensorBoard]
4. Export the model
Run
python3 export_inference_graph.py \
--trained_checkpoint_prefix Path/to/model.ckpt-xxx \
--output_directory Path/to/exported_pb_file_directory
to convert the .ckpt file into a .pb file. If you used the default number of training steps, num_steps=5000, set trained_checkpoint_prefix to the path of model.ckpt-5000, e.g. ./training/model.ckpt-5000. Set output_directory to the folder where the exported .pb file should be saved, e.g. ./training/frozen_inference_graph_pb; the frozen_inference_graph.pb generated in that folder is the model you will ultimately load.
5. Use the model
See the predict.py file.
[Note]
If you use the cats-vs-dogs dataset, download it as is, do not rename any of its images, and point images_dir at the xxx/xxx/train folder when generating the tfrecords. If you use a different dataset, you need to modify the provide function in data_provide.py: it must return a dictionary whose key-value pairs are image path: class label, for example
{'E:/xxx/train_images/1.jpg': 0,
'E:/xxx/train_images/2.jpg': 1,
...,
'E:/xxx/train_images/10000.jpg': 10}
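As an illustration of that contract, here is a sketch of a provide-style function for images laid out like the cats-vs-dogs training folder. It assumes Kaggle-style file names such as cat.123.jpg and dog.456.jpg and uses this article's labels (cat 0, dog 1). The function name and the file-name rule are assumptions, not the repository's data_provide.py; for another dataset, only the label rule would need to change.

```python
import os


def provide_cats_dogs(images_dir):
    """Hypothetical provide(): map each image path to its class label.

    Assumes file names like 'cat.123.jpg' / 'dog.456.jpg'; cats get label 0
    and dogs label 1. Files that match neither prefix are ignored.
    """
    label_for_prefix = {'cat': 0, 'dog': 1}
    image_files = {}
    for name in sorted(os.listdir(images_dir)):
        prefix = name.split('.', 1)[0].lower()
        if prefix in label_for_prefix and name.lower().endswith('.jpg'):
            image_files[os.path.join(images_dir, name)] = label_for_prefix[prefix]
    return image_files
```

This is also why the images should not be renamed: in this sketch the label is recovered entirely from the file name.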