介紹
最近在讀取tfrecord時(shí),遇到了關(guān)于tensorf shape的問題。
我們需要知道,大多數(shù)情況下圖片進(jìn)行encode編碼保存在tfrecord時(shí) 是一個(gè)一維張量,shape為(1,)。 而在輸入神經(jīng)網(wǎng)絡(luò)之前,我們必須要將這個(gè)圖片張量reshape成一個(gè)合乎網(wǎng)絡(luò)結(jié)構(gòu)需求的三維張量。
在針對(duì)這樣的需求時(shí),我們會(huì)發(fā)現(xiàn),大部分同學(xué)會(huì)選擇在生成tfrecord前就定義好網(wǎng)絡(luò)的輸入shape,例如[224,224,3], 然后將所有的圖片先reshape成這個(gè)大小,接著存儲(chǔ)在tfrecord中。
這種方式的優(yōu)點(diǎn)在于提前完成的reshape,避免了后續(xù)很多的shape uncompatible 的問題,以及后續(xù)訓(xùn)練中不用再對(duì)圖片進(jìn)行reshape,加快了訓(xùn)練速度。
缺點(diǎn)在于,限制了網(wǎng)絡(luò)輸入尺寸的定義。每修改一次神經(jīng)網(wǎng)絡(luò)的輸入shape。
當(dāng)我們需要從存儲(chǔ)著不定尺寸圖片的tfrecord讀取數(shù)據(jù)時(shí), 我們是無(wú)法直接將圖片reshape成指定的網(wǎng)絡(luò)結(jié)構(gòu)輸入尺寸的。例如圖片大小 [667,1085,3]。顯然,我們無(wú)法直接將其reshape成 [224,224,3]的。那么我們?cè)撊绾翁幚砟兀?/p>
按照思路,我們應(yīng)該先將圖片的一維tensor 轉(zhuǎn)換成三維tensor, 然后再利用 tf.image庫(kù)中不同的reshape 操作,將三維圖片tensor轉(zhuǎn)換為需要的 tensor大小。
按照這種思路,在這里,我總結(jié)了兩種讀寫tfrecord的方式,并對(duì)這兩種方式的不同點(diǎn),尤其是容易導(dǎo)致bug的地方進(jìn)行了整理。
第一種: 利用slim.dataset.Dataset讀寫tfrecord文件,這種方式常見于利用slim庫(kù)進(jìn)行目標(biāo)檢測(cè)等網(wǎng)絡(luò)的實(shí)現(xiàn)過(guò)程中。
第二種:tf.parse_single_example 是更為常見的一種方式
利用slim.dataset.Dataset讀寫tfrecord文件
利用這個(gè)這個(gè)接口讀寫tfrecord非常的方便。它的神奇之處在于,
它不需要圖片寬高的信息,只需要其二進(jìn)制string tensor。 這個(gè)接口會(huì)自動(dòng)返回一個(gè)三維圖片tensor。 在此基礎(chǔ)上,我們可以很方便的對(duì)其進(jìn)行reshape,然后輸入神經(jīng)網(wǎng)絡(luò)。
具體步驟如下:
在生成tfrecord文件時(shí),我們需要先定義 tf_example的寫入格式,然后在將圖片文件依據(jù)這個(gè)寫入格式,生成tfrecord文件
- 定義 tf_example的寫入特征
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 對(duì)于可能存在RGBA 4通道的圖片進(jìn)行處理
image,is_process = process_image_channels(image)
# 如有必要,那么就在生成tfrecord時(shí)即進(jìn)行resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
# update encode_jpg
if is_process or resize_size is not None:
bytes_io = io.BytesIO()
image.save(bytes_io, format='JPEG')
encoded_jpg = bytes_io.getvalue()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
return tf_example
- 生成完整的tfrecord文件
在定義完對(duì)應(yīng)的tf_example 方式后,我們可以遍歷圖片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
num_valid_tf_example = 0
writer = tf.python_io.TFRecordWriter(output_path)
for image_path, label in annotation_dict.items():
if not tf.gfile.GFile(image_path):
print('%s does not exist.' % image_path)
continue
tf_example = create_tf_example(image_path, label, resize_size)
if tf_example:
writer.write(tf_example.SerializeToString())
num_valid_tf_example += 1
if num_valid_tf_example % 100 == 0:
print('Create %d TF_Example.' % num_valid_tf_example)
writer.close()
print('Total create TF_Example: %d' % num_valid_tf_example)
對(duì)應(yīng)著,在讀取tfrecord時(shí),slim提供了 slim.dataset.Dataset 的API接口,非常方便對(duì)讀入的tfrecord數(shù)據(jù)進(jìn)行操作。
def get_record_dataset(record_path,
reader=None,
num_samples=50000,
num_classes=32):
"""Get a tensorflow record file.
Args:
"""
if not reader:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/class/label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64))}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(image_key='image/encoded',
format_key='image/format'),
'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
items_to_descriptions = {
'image': 'An image with shape image_shape.',
'label': 'A single integer.'}
return slim.dataset.Dataset(
data_sources=record_path,
reader=reader,
decoder=decoder,
num_samples=num_samples,
num_classes=num_classes,
items_to_descriptions=items_to_descriptions,
labels_to_names=labels_to_names)
在返回了slim.dataset.Dataset這個(gè)slim支持的data封裝后, 我們可直接對(duì)返回的圖片數(shù)據(jù)進(jìn)行reshape,保證這個(gè)圖片張量的shape與網(wǎng)絡(luò)結(jié)構(gòu)的輸入層shape一致。
dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples,
num_classes=FLAGS.num_classes)
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = data_provider.get(['image', 'label'])
# 輸出當(dāng)前tensor的靜態(tài)shape 和動(dòng)態(tài)shape,與另一種讀取方式進(jìn)行對(duì)比
print("----------tf.shape(image): ",tf.shape(image))
print("----------image.get_shape(): ",image.get_shape())
image = _fixed_sides_resize(image, output_height=368, output_width=368)
inputs, labels = tf.train.batch([image, label],
batch_size=FLAGS.batch_size,
#capacity=5*FLAGS.batch_size,
allow_smaller_final_batch=True)
其中,對(duì)三維圖片張量進(jìn)行reshape的代碼如下
def _fixed_sides_resize(image, output_height, output_width):
"""Resize images by fixed sides.
Args:
image: A 3-D image `Tensor`.
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
image = tf.expand_dims(image, 0)
resized_image = tf.image.resize_nearest_neighbor(
image, [output_height, output_width], align_corners=False)
resized_image = tf.squeeze(resized_image)
resized_image.set_shape([None, None, 3])
return resized_image
完成了這幾步之后,我們就可以利用image 和 label 進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練了。
利用tf.parse_single_example 讀寫tfrecord文件
這種方式我們需要自己手動(dòng)將一維的圖片tensor,先還原成三維圖片tensor。 因?yàn)槊恳粡垐D片的shape不相同。那么我們需要將圖片的shape也存入tfrecord文件中。當(dāng)我們從tfrecord文件中讀取時(shí),我們先利用tf.reshape將一維圖片張量還原成三維圖片張量,再reshape規(guī)定的網(wǎng)絡(luò)輸入尺寸。
- 照例,此處的重點(diǎn)在于tf_example的構(gòu)建。在這一部分,我將圖片的shape作為一個(gè)feature,也存入了tfrecord里面。 那么,在對(duì)張量的還原時(shí),我們可以利用這個(gè)三維的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 對(duì)于RGBA 4通道的圖片進(jìn)行處理
image,is_process = process_image_channels(image)
# Resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
img_array = np.asarray(image)
shape = img_array.shape
byte_image = image.tobytes()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image': bytes_feature(byte_image),
'label': int64_feature(label),
'img_shape': int64_list_feature(shape)}))
return tf_example
在完成這個(gè)后,我們?nèi)耘f可以使用上述提及的generate_tfrecord 函數(shù)來(lái)生成對(duì)應(yīng)的tfrecord
那么,對(duì)應(yīng)這種方式生成的tfrecord文件,我們?cè)撊绾巫x取呢?
在這里,我給出對(duì)應(yīng)的parse_example函數(shù)就足以了。
def parse(serialized):
# Define a dict with the data-names and types we expect to
# find in the TFRecords file.
# It is a bit awkward that this needs to be specified again,
# because it could have been written in the header of the
# TFRecords file instead.
features = {
'image':
tf.FixedLenFeature((), tf.string, default_value=''),
'label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64)),
'img_shape':
tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.parse_single_example(
serialized=serialized, features=features)
# Get the image as raw bytes.
image_raw = parsed_example['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.decode_raw(image_raw, tf.uint8)
# The type is now uint8 but we need it to be float.
image = tf.cast(image, tf.float32)
shape = parsed_example['img_shape']
image = tf.reshape(image,shape=shape)
if not (shape[0] == shape[1] == default_img_size):
image = _fixed_sides_resize(image,default_img_size,default_img_size)
image.set_shape([default_img_size,default_img_size,3])
label = parsed_example['label']
# The image and label are now correct TensorFlow types.
return image, label
在這里,讀寫tfrecord的重要流程就已經(jīng)展現(xiàn)好了。
對(duì)比
這兩種方式有一個(gè)比較重要的區(qū)別,那就是制作tfrecord時(shí)存儲(chǔ)的圖片信息不同。
使用slim api時(shí) 我們制作tfrecord 時(shí),相關(guān)代碼為
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
當(dāng)我們使用第二種方式時(shí),制作tfrecord時(shí)存儲(chǔ)的圖片信息的相關(guān)代碼如下所示。
image = Image.open(img_dir)
byte_image = image.tobytes()
第一種方式保存的圖片信息,其字節(jié)數(shù)不等于圖片的height, width, channel的乘積。 所以不能用 第二種的方式去讀取這種方式存儲(chǔ)的tfrecord。 會(huì)出現(xiàn) reshape時(shí) 維度不對(duì)的錯(cuò)誤。 當(dāng)然,使用slim.dataset.Dataset 則不需要考慮這個(gè)問題了。 網(wǎng)絡(luò)上使用slim.dataset.Dataset 來(lái)加載tfrecord的方式,都是使用第一種方式存儲(chǔ)的tfrecord數(shù)據(jù)。
第二種方式,其存儲(chǔ)的圖片字節(jié)大小等于圖片的height, width, channel的乘積。所以它可以直接用tf.reshape直接將原圖矩陣還原回來(lái),然后再進(jìn)行下一步的reshape操作。
總結(jié)
之所以寫這篇文章,是因?yàn)榫W(wǎng)絡(luò)上針對(duì)不定尺寸圖片tfrecord讀取的解決方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
將height, width,channel 分別存入tfrecord,然后按照提問者描述這樣是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解決方案
image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))
這種方式在tf.reshape階段會(huì)報(bào)錯(cuò),因?yàn)槲覀儫o(wú)法將 兩個(gè)tensor和一個(gè)int數(shù)值組合起來(lái)。最完善的方式是直接將shape作為一個(gè)整體存入tfrecord中,最終讀取出來(lái)就是一個(gè)張量了。