Tensorflow針對(duì)不定尺寸的圖片讀寫tfrecord文件總結(jié)

介紹

最近在讀取tfrecord時(shí),遇到了關(guān)于tensorf shape的問題。

我們需要知道,大多數(shù)情況下圖片進(jìn)行encode編碼保存在tfrecord時(shí) 是一個(gè)一維張量,shape為(1,)。 而在輸入神經(jīng)網(wǎng)絡(luò)之前,我們必須要將這個(gè)圖片張量reshape成一個(gè)合乎網(wǎng)絡(luò)結(jié)構(gòu)需求的三維張量。
在針對(duì)這樣的需求時(shí),我們會(huì)發(fā)現(xiàn),大部分同學(xué)會(huì)選擇在生成tfrecord前就定義好網(wǎng)絡(luò)的輸入shape,例如[224,224,3], 然后將所有的圖片先reshape成這個(gè)大小,接著存儲(chǔ)在tfrecord中。
這種方式的優(yōu)點(diǎn)在于提前完成的reshape,避免了后續(xù)很多的shape uncompatible 的問題,以及后續(xù)訓(xùn)練中不用再對(duì)圖片進(jìn)行reshape,加快了訓(xùn)練速度。
缺點(diǎn)在于,限制了網(wǎng)絡(luò)輸入尺寸的定義。每修改一次神經(jīng)網(wǎng)絡(luò)的輸入shape。

當(dāng)我們需要從存儲(chǔ)著不定尺寸圖片的tfrecord讀取數(shù)據(jù)時(shí), 我們是無(wú)法直接將圖片reshape成指定的網(wǎng)絡(luò)結(jié)構(gòu)輸入尺寸的。例如圖片大小 [667,1085,3]。顯然,我們無(wú)法直接將其reshape成 [224,224,3]的。那么我們?cè)撊绾翁幚砟兀?/p>

按照思路,我們應(yīng)該先將圖片的一維tensor 轉(zhuǎn)換成三維tensor, 然后再利用 tf.image庫(kù)中不同的reshape 操作,將三維圖片tensor轉(zhuǎn)換為需要的 tensor大小。

按照這種思路,在這里,我總結(jié)了兩種讀寫tfrecord的方式,并對(duì)這兩種方式的不同點(diǎn),尤其是容易導(dǎo)致bug的地方進(jìn)行了整理。

第一種: 利用slim.dataset.Dataset讀寫tfrecord文件,這種方式常見于利用slim庫(kù)進(jìn)行目標(biāo)檢測(cè)等網(wǎng)絡(luò)的實(shí)現(xiàn)過(guò)程中。
第二種:tf.parse_single_example 是更為常見的一種方式

利用slim.dataset.Dataset讀寫tfrecord文件

利用這個(gè)這個(gè)接口讀寫tfrecord非常的方便。它的神奇之處在于,
它不需要圖片寬高的信息,只需要其二進(jìn)制string tensor。 這個(gè)接口會(huì)自動(dòng)返回一個(gè)三維圖片tensor。 在此基礎(chǔ)上,我們可以很方便的對(duì)其進(jìn)行reshape,然后輸入神經(jīng)網(wǎng)絡(luò)。
具體步驟如下:
在生成tfrecord文件時(shí),我們需要先定義 tf_example的寫入格式,然后在將圖片文件依據(jù)這個(gè)寫入格式,生成tfrecord文件

  • 定義 tf_example的寫入特征
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)

    # 對(duì)于可能存在RGBA 4通道的圖片進(jìn)行處理
    image,is_process = process_image_channels(image)

    # 如有必要,那么就在生成tfrecord時(shí)即進(jìn)行resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    # update encode_jpg
    if is_process or resize_size is not None:
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')
        encoded_jpg = bytes_io.getvalue()

    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))
    return tf_example
  • 生成完整的tfrecord文件
    在定義完對(duì)應(yīng)的tf_example 方式后,我們可以遍歷圖片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
    num_valid_tf_example = 0
    writer = tf.python_io.TFRecordWriter(output_path)
    for image_path, label in annotation_dict.items():
        if not tf.gfile.GFile(image_path):
            print('%s does not exist.' % image_path)
            continue
        tf_example = create_tf_example(image_path, label, resize_size)
        if tf_example:
            writer.write(tf_example.SerializeToString())
            num_valid_tf_example += 1

            if num_valid_tf_example % 100 == 0:
                print('Create %d TF_Example.' % num_valid_tf_example)
    writer.close()
    print('Total create TF_Example: %d' % num_valid_tf_example)

對(duì)應(yīng)著,在讀取tfrecord時(shí),slim提供了 slim.dataset.Dataset 的API接口,非常方便對(duì)讀入的tfrecord數(shù)據(jù)進(jìn)行操作。

def get_record_dataset(record_path,
                       reader=None, 
                       num_samples=50000, 
                       num_classes=32):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)

在返回了slim.dataset.Dataset這個(gè)slim支持的data封裝后, 我們可直接對(duì)返回的圖片數(shù)據(jù)進(jìn)行reshape,保證這個(gè)圖片張量的shape與網(wǎng)絡(luò)結(jié)構(gòu)的輸入層shape一致。

   dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                 num_classes=FLAGS.num_classes)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # 輸出當(dāng)前tensor的靜態(tài)shape 和動(dòng)態(tài)shape,與另一種讀取方式進(jìn)行對(duì)比
    print("----------tf.shape(image): ",tf.shape(image))
    print("----------image.get_shape(): ",image.get_shape())
    image = _fixed_sides_resize(image, output_height=368, output_width=368)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)

其中,對(duì)三維圖片張量進(jìn)行reshape的代碼如下

def _fixed_sides_resize(image, output_height, output_width):
    """Resize images by fixed sides.
    
    Args:
        image: A 3-D image `Tensor`.
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor containing the resized image.
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    image = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        image, [output_height, output_width], align_corners=False)
    resized_image = tf.squeeze(resized_image)
    resized_image.set_shape([None, None, 3])
    return resized_image

完成了這幾步之后,我們就可以利用image 和 label 進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練了。

利用tf.parse_single_example 讀寫tfrecord文件

這種方式我們需要自己手動(dòng)將一維的圖片tensor,先還原成三維圖片tensor。 因?yàn)槊恳粡垐D片的shape不相同。那么我們需要將圖片的shape也存入tfrecord文件中。當(dāng)我們從tfrecord文件中讀取時(shí),我們先利用tf.reshape將一維圖片張量還原成三維圖片張量,再reshape規(guī)定的網(wǎng)絡(luò)輸入尺寸。

  • 照例,此處的重點(diǎn)在于tf_example的構(gòu)建。在這一部分,我將圖片的shape作為一個(gè)feature,也存入了tfrecord里面。 那么,在對(duì)張量的還原時(shí),我們可以利用這個(gè)三維的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    # 對(duì)于RGBA 4通道的圖片進(jìn)行處理
    image,is_process = process_image_channels(image)

    # Resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    
    img_array = np.asarray(image)
    shape = img_array.shape
    byte_image = image.tobytes()
    
    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image': bytes_feature(byte_image),
            'label': int64_feature(label),
            'img_shape': int64_list_feature(shape)}))
    return tf_example
  • 在完成這個(gè)后,我們?nèi)耘f可以使用上述提及的generate_tfrecord 函數(shù)來(lái)生成對(duì)應(yīng)的tfrecord

  • 那么,對(duì)應(yīng)這種方式生成的tfrecord文件,我們?cè)撊绾巫x取呢?
    在這里,我給出對(duì)應(yīng)的parse_example函數(shù)就足以了。

def parse(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.

    features = {
        'image':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'label':
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
                                                                     dtype=tf.int64)),
        'img_shape': 
            tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.parse_single_example(
        serialized=serialized, features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.decode_raw(image_raw, tf.uint8)
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)
    
    shape = parsed_example['img_shape']
    
    image = tf.reshape(image,shape=shape)
    
    if not (shape[0] == shape[1] == default_img_size):
        image = _fixed_sides_resize(image,default_img_size,default_img_size)
    
    image.set_shape([default_img_size,default_img_size,3])
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label

在這里,讀寫tfrecord的重要流程就已經(jīng)展現(xiàn)好了。

對(duì)比

這兩種方式有一個(gè)比較重要的區(qū)別,那就是制作tfrecord時(shí)存儲(chǔ)的圖片信息不同。
使用slim api時(shí) 我們制作tfrecord 時(shí),相關(guān)代碼為

    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()

當(dāng)我們使用第二種方式時(shí),制作tfrecord時(shí)存儲(chǔ)的圖片信息的相關(guān)代碼如下所示。

image = Image.open(img_dir)
byte_image = image.tobytes()

第一種方式保存的圖片信息,其字節(jié)數(shù)不等于圖片的height, width, channel的乘積。 所以不能用 第二種的方式去讀取這種方式存儲(chǔ)的tfrecord。 會(huì)出現(xiàn) reshape時(shí) 維度不對(duì)的錯(cuò)誤。 當(dāng)然,使用slim.dataset.Dataset 則不需要考慮這個(gè)問題了。 網(wǎng)絡(luò)上使用slim.dataset.Dataset 來(lái)加載tfrecord的方式,都是使用第一種方式存儲(chǔ)的tfrecord數(shù)據(jù)。

第二種方式,其存儲(chǔ)的圖片字節(jié)大小等于圖片的height, width, channel的乘積。所以它可以直接用tf.reshape直接將原圖矩陣還原回來(lái),然后再進(jìn)行下一步的reshape操作。

總結(jié)

之所以寫這篇文章,是因?yàn)榫W(wǎng)絡(luò)上針對(duì)不定尺寸圖片tfrecord讀取的解決方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
將height, width,channel 分別存入tfrecord,然后按照提問者描述這樣是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解決方案

image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))

這種方式在tf.reshape階段會(huì)報(bào)錯(cuò),因?yàn)槲覀儫o(wú)法將 兩個(gè)tensor和一個(gè)int數(shù)值組合起來(lái)。最完善的方式是直接將shape作為一個(gè)整體存入tfrecord中,最終讀取出來(lái)就是一個(gè)張量了。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容