TensorFlow data reading and loading

First, read the official API guide:

https://www.tensorflow.org/api_guides/python/reading_data

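There are three approaches:

  • placeholder: feed the data in yourself. You have to write your own data iterator and shuffling, and control the number of epochs yourself.
  • Reading from files: TFRecord, CSV, etc.
  • Preloaded data: everything is held in memory. This is only suitable for small datasets.

Reading from files is recommended, specifically TFRecord. TF then loads and shuffles the files automatically and can also limit the number of epochs.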
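For comparison, the first approach means writing the batching logic yourself. Below is a minimal pure-Python sketch. It assumes the whole dataset fits in two in-memory lists. The function `batch_iterator` and the placeholder names in the comment are hypothetical, not part of TensorFlow:

```python
import random


def batch_iterator(xs, ys, batch_size, num_epochs, seed=None):
    """Yield (x_batch, y_batch) pairs for num_epochs passes, reshuffling every epoch.

    Each batch would be fed as, e.g.,
    sess.run(train_op, feed_dict={x_ph: x_batch, y_ph: y_batch})
    where x_ph / y_ph are hypothetical tf.placeholder tensors.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    for _ in range(num_epochs):
        rng.shuffle(indices)  # new order each epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]  # last batch may be smaller
            yield [xs[i] for i in batch], [ys[i] for i in batch]
```

In this approach, epoch control is nothing more than the outer `for` loop. This is the bookkeeping that the file-based queue pipeline takes over.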

Converting the training data to TFRecord

import tensorflow as tf  # TensorFlow 1.x API

def write_tfrecord(writer, char_ids, label_id):
    """
    :param writer: tf record writer
    :param char_ids: list of int, padded/truncated to max_seq_len
    :param label_id: int
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'char_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=char_ids)),
        # Int64List expects a list, so the scalar label is wrapped in one
        'label_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[label_id]))
    }))
    writer.write(example.SerializeToString())

writer = tf.python_io.TFRecordWriter(output_file_name)
for line in open(your_files):
    char_ids = ...  # list of int ids for this line, padded to max_seq_len
    label_id = ...  # int
    write_tfrecord(writer, char_ids, label_id)
writer.close()
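The reader below parses `char_ids` with `tf.FixedLenFeature([max_seq_len], tf.int64)`, so every record must hold exactly `max_seq_len` ids. Below is a minimal sketch of a hypothetical helper that enforces this before writing. It assumes id 0 is reserved for padding:

```python
def pad_or_truncate(char_ids, max_seq_len, pad_id=0):
    """Return a list of exactly max_seq_len ids.

    Longer sequences are cut off; shorter ones are right-padded with pad_id
    (assumed here to be a reserved padding id, not a real character).
    """
    ids = list(char_ids[:max_seq_len])
    return ids + [pad_id] * (max_seq_len - len(ids))
```

In the writing loop it would be called as `write_tfrecord(writer, pad_or_truncate(char_ids, max_seq_len), label_id)`. If the model needs the true length, one option is to store it as an additional int64 feature.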

Reading and decoding TFRecord

  • Plain decoding (iterate over the records directly, without queues)
import tensorflow as tf  # TensorFlow 1.x API

def test_read_tfrecords():
    filename = "./data/train.tfrecords"
    # tf_record_iterator yields each record as a serialized byte string
    for serialized_example in tf.python_io.tf_record_iterator(filename):
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        # traverse the Example format to get data
        x = example.features.feature['char_ids'].int64_list.value
        y = example.features.feature['label_id'].int64_list.value  # a list holding one label
        # do something
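`tf.python_io.tf_record_iterator` yields the raw serialized bytes of each record. That is why the loop calls `ParseFromString`. On disk, each uncompressed TFRecord entry is framed as four parts:

  • a little-endian uint64 length
  • a masked CRC-32C of that length
  • the data bytes
  • a masked CRC-32C of the data

The pure-Python sketch below writes and reads that framing. It is intended to illustrate the documented format, not to replace the TensorFlow reader. The names `write_record` and `iter_records` are hypothetical:

```python
import io
import struct


def _make_crc32c_table():
    # CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = (crc >> 8) ^ _CRC32C_TABLE[(crc ^ b) & 0xFF]
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord masking: rotate the CRC right by 15 bits, then add a constant."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, data: bytes) -> None:
    header = struct.pack("<Q", len(data))          # uint64 length
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))  # masked CRC of the length
    f.write(data)
    f.write(struct.pack("<I", masked_crc(data)))    # masked CRC of the data


def _read_exact(f, n):
    buf = f.read(n)
    if len(buf) != n:
        raise ValueError("truncated record")
    return buf


def iter_records(f):
    """Yield the payload of each record, verifying both checksums."""
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated record")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        data = _read_exact(f, length)
        (data_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record data")
        yield data


# Round trip through an in-memory file
buf = io.BytesIO()
for payload in [b"hello", b"", b"x" * 300]:
    write_record(buf, payload)
buf.seek(0)
records = list(iter_records(buf))
```

In real use, each payload would be the output of `example.SerializeToString()`.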

Connecting TFRecord, batching, and training

import tensorflow as tf  # TensorFlow 1.x API

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            # max_seq_len: the fixed length every char_ids list was padded to when writing
            'char_ids': tf.FixedLenFeature([max_seq_len], tf.int64),
            'label_id': tf.FixedLenFeature([1], tf.int64)
        })

    char_ids = tf.cast(features['char_ids'], tf.int32)
    # one-hot variant: tf.one_hot(tf.cast(features['label_id'], tf.int32)[0], len(data_set.label_dict))
    # label_id has shape [1]; indexing [0] turns it into a scalar
    label_id = tf.cast(features['label_id'], tf.int32)[0]
    return char_ids, label_id

def inputs(filename_list, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filename_list, num_epochs=num_epochs)
    char_ids, label_id = read_and_decode(filename_queue)
    batch_char_ids, batch_label_id = tf.train.shuffle_batch(
        [char_ids, label_id], batch_size=batch_size, num_threads=12,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000)

    return batch_char_ids, batch_label_id
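`tf.train.shuffle_batch` draws from a random shuffle queue that keeps at least `min_after_dequeue` elements buffered. A larger value therefore mixes better, at the cost of memory and start-up time. Below is a simplified, single-threaded pure-Python model of that behavior. The function names are hypothetical, and the real queue is multi-threaded and bounded by `capacity`:

```python
import random


def shuffle_buffer(stream, min_after_dequeue, seed=None):
    """Simplified model of the queue behind tf.train.shuffle_batch.

    An item is emitted only once more than min_after_dequeue items are buffered,
    and the emitted item is chosen uniformly at random from the buffer.
    """
    rng = random.Random(seed)
    buf = []
    for item in stream:
        buf.append(item)
        if len(buf) > min_after_dequeue:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]  # move the chosen item to the end
            yield buf.pop()
    # Input exhausted (queue closed): drain the rest, still in random order
    rng.shuffle(buf)
    yield from buf


def shuffled_batches(stream, batch_size, min_after_dequeue, seed=None):
    """Group shuffled items into full batches. A trailing partial batch is dropped,
    mirroring shuffle_batch's default allow_smaller_final_batch=False."""
    batch = []
    for item in shuffle_buffer(stream, min_after_dequeue, seed):
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
```

In this model, the first element emitted can only come from the first `min_after_dequeue + 1` inputs. If a file is written sorted by label, a small buffer will therefore produce batches dominated by one label.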

# main: build the input pipeline and the model BEFORE starting the queue runners
x_batch, y_batch = inputs(data_files, batch_size, num_epochs=50)
train_op = ...  # build the model and optimizer on x_batch / y_batch

# num_epochs is stored in a local variable, so local variables must be initialized too
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
sess = tf.Session()
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()
# Wait for threads to finish.
coord.join(threads)
sess.close()
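The training loop ends through `tf.errors.OutOfRangeError`. Once `string_input_producer` has emitted the file list `num_epochs` times, the queues are closed and the next dequeue fails. Below is a simplified single-threaded model of that control flow. `FilenameProducer` and the local `OutOfRangeError` class are hypothetical stand-ins, not TensorFlow classes:

```python
import random


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class FilenameProducer:
    """Simplified model of string_input_producer(filenames, num_epochs=..., shuffle=True):
    the file list is reshuffled every epoch, and dequeue() fails after num_epochs."""

    def __init__(self, filenames, num_epochs=None, shuffle=True, seed=None):
        if not filenames:
            raise ValueError("filenames must not be empty")
        self._filenames = list(filenames)
        self._num_epochs = num_epochs  # None means repeat forever
        self._shuffle = shuffle
        self._rng = random.Random(seed)
        self._epoch = 0
        self._pending = []

    def dequeue(self):
        if not self._pending:
            if self._num_epochs is not None and self._epoch >= self._num_epochs:
                raise OutOfRangeError("epoch limit reached")
            self._pending = list(self._filenames)
            if self._shuffle:
                self._rng.shuffle(self._pending)
            self._epoch += 1
        return self._pending.pop()


def train(producer):
    """Mirror of the try/except loop: run steps until the input is exhausted."""
    steps = 0
    try:
        while True:
            producer.dequeue()
            steps += 1  # stands in for sess.run(train_op)
    except OutOfRangeError:
        print("Done training -- epoch limit reached")
    return steps
```

With `num_epochs=None`, the real pipeline never raises this error. In that case the loop has to be stopped some other way, for example after a fixed number of steps.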