わかりやすくするために、1 つの .tfrecords ファイルに数千の画像があります。これらは 720 x 720 の rgb png ファイルです。ラベルは 0、1、2、3 のいずれかです。
parse_example も使用してみましたが、うまくいきませんでしたが、このソリューションは parse_single_example で動作します。
欠点は、現在、各 .tf レコードにいくつのアイテムがあるかを知る必要があることです。これはちょっと残念です。より良い方法が見つかったら、回答を更新します。また、.tfrecords ファイル内のレコード数の範囲外に注意してください。最後のレコードを超えてループすると、最初のレコードからやり直されます。
秘訣は、キュー ランナーにコーディネーターを使用させることでした。
画像が正しいことを確認できるように、画像が読み込まれるときに画像を保存するためのコードをここに残しました。
from PIL import Image
import numpy as np
import tensorflow as tf
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
return image, label, height, width, depth
def get_all_records(FILE):
with tf.Session() as sess:
filename_queue = tf.train.string_input_producer([ FILE ])
image, label, height, width, depth = read_and_decode(filename_queue)
image = tf.reshape(image, tf.pack([height, width, 3]))
image.set_shape([720,720,3])
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(2053):
example, l = sess.run([image, label])
img = Image.fromarray(example, 'RGB')
img.save( "output/" + str(i) + '-train.png')
print (example,l)
coord.request_stop()
coord.join(threads)
get_all_records('/path/to/train-0.tfrecords')