書き込み機能:
def getRecordData(fileName, outFile):
with tf.io.gfile.GFile(fileName, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
y = []
with open(outFile) as outFile:
# ...populate y....
return {
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded_jpg_io)),
'output': tf.train.Feature(float_list=tf.train.FloatList(value=y))
}
tfrecords の解析:
def parseExample(example):
features = {
"image": tf.io.FixedLenFeature([], tf.string),
"output": tf.io.FixedLenFeature([], tf.float32)
}
parsed = tf.io.parse_single_example(example, features=features)
image = tf.image.decode_png(parsed["image"], channels=3)
return image, parsed["output"]
def make_dataset(dir, dtype, dataSetType, parse_fn):
dataset = tf.data.TFRecordDataset(...path...)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.map(parseExample)
dataset = dataset.batch(batch_size=32)
dataset.cache('E:\\trainingcache')
return dataset
画像が正しく読み込まれたかどうかを確認しようとすると
dataset = make_dataset(args.records_dir, 'training', 'tables', parseExample)
for image_features in dataset:
image_raw = image_features['image'].numpy()
display.display(display.Image(data=image_raw))
私は得る:
example_parsing_ops.cc:240 : 無効な引数: キー: 出力。シリアル化された例を解析できません。