tf.estimator.Estimator train メソッドを使用して、TFRecord データセット全体で cnn をトレーニングするのに苦労しています。
次のように、ループ内でトレインを実行しようとしました。
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=MODEL_FOLDER)
input_fn = generate_input_fn(path, [], batch_size=128,
shuffle=True, num_epochs=None)
while True:
estimator.train(
input_fn=input_fn, steps=1, hooks=[logging_hook])
私のinput_fnは次のようになります。
def generate_input_fn(file_pattern, given_labels, batch_size=1,
num_epochs=None, shuffle=False):
def _input_fn():
print("_input_fn: file pattern: " + file_pattern)
filenames_tensor = tf.train.match_filenames_once(file_pattern)
filename_queue = tf.train.string_input_producer(
filenames_tensor,
num_epochs=num_epochs,
shuffle=shuffle)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image/width': tf.FixedLenFeature([], tf.int64),
'image/height': tf.FixedLenFeature([], tf.int64),
'image/class/label': tf.FixedLenFeature([LABELS_SIZE], tf.int64),
'image/encoded': tf.FixedLenFeature([], tf.string),
'image/format': tf.FixedLenFeature([], tf.string),
'image/name': tf.FixedLenFeature([], tf.string)
})
labels = features['image/class/label']
filename = features['image/name']
image = tf.image.decode_jpeg(
features["image/encoded"], channels=IMAGE_CHANNELS)
image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS])
image = tf.image.resize_image_with_crop_or_pad(
image, IMAGE_HEIGHT, IMAGE_WIDTH)
image_batch, batch_labels, filename_batch = tf.train.shuffle_batch(
[image, labels, filename],
batch_size,
num_threads=8,
capacity=5000,
min_after_dequeue=1000
# allow_smaller_final_batch=True
)
# so that the "center" of the image range is roughly 0.
image_batch = tf.to_float(image_batch) / 255
image_batch = (image_batch * 2) - 1
features = {
"image": image_batch,
"filename": filename_batch
}
return features, batch_labels
return _input_fn
私のmodel_fnには次のコードがあります:
logits = tf.Print(logits, [logits], "Logits: ")
features['filename'] = tf.Print(features['filename'], [features['filename']], 'Filename: ')
tf.summary.text('filename', features['filename'])
しかし、model_fn でファイル名を tf.Print すると、すべての実行で同じバッチを取得しているように見えます。これまでのところ、私は次のことを試みました。* リーダーを generate_input_fn の範囲外に移動しようとしましたが、入力テンソルが別のグラフからのものであると表示されています
私が間違っていることについて何か考えはありますか?助けてくれてありがとう!