2

tf.estimator.Estimator train メソッドを使用して、TFRecord データセット全体で cnn をトレーニングするのに苦労しています。

次のように、ループ内でトレインを実行しようとしました。

estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir=MODEL_FOLDER)
input_fn = generate_input_fn(path, [], batch_size=128,
                             shuffle=True, num_epochs=None)
while True:
    estimator.train(
        input_fn=input_fn, steps=1, hooks=[logging_hook])

私のinput_fnは次のようになります。

def generate_input_fn(file_pattern, given_labels, batch_size=1,
                      num_epochs=None, shuffle=False):
    def _input_fn():
        print("_input_fn: file pattern: " + file_pattern)

        filenames_tensor = tf.train.match_filenames_once(file_pattern)
        filename_queue = tf.train.string_input_producer(
            filenames_tensor,
            num_epochs=num_epochs,
            shuffle=shuffle)

        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/width': tf.FixedLenFeature([], tf.int64),
                'image/height': tf.FixedLenFeature([], tf.int64),
                'image/class/label': tf.FixedLenFeature([LABELS_SIZE], tf.int64),
                'image/encoded': tf.FixedLenFeature([], tf.string),
                'image/format': tf.FixedLenFeature([], tf.string),
                'image/name': tf.FixedLenFeature([], tf.string)
            })

        labels = features['image/class/label']
        filename = features['image/name']

        image = tf.image.decode_jpeg(
            features["image/encoded"], channels=IMAGE_CHANNELS)
        image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS])

        image = tf.image.resize_image_with_crop_or_pad(
            image, IMAGE_HEIGHT, IMAGE_WIDTH)

        image_batch, batch_labels, filename_batch = tf.train.shuffle_batch(
            [image, labels, filename],
            batch_size,
            num_threads=8,
            capacity=5000,
            min_after_dequeue=1000
            # allow_smaller_final_batch=True
        )

        # so that the "center" of the image range is roughly 0.
        image_batch = tf.to_float(image_batch) / 255
        image_batch = (image_batch * 2) - 1

        features = {
            "image": image_batch,
            "filename": filename_batch
        }

        return features, batch_labels
    return _input_fn

私のmodel_fnには次のコードがあります:

logits = tf.Print(logits, [logits], "Logits: ")
features['filename'] = tf.Print(features['filename'], [features['filename']], 'Filename: ')
tf.summary.text('filename', features['filename'])

しかし、model_fn でファイル名を tf.Print すると、すべての実行で同じバッチを取得しているように見えます。これまでのところ、私は次のことを試みました。* リーダーを generate_input_fn の範囲外に移動しようとしましたが、入力テンソルが別のグラフからのものであると表示されています

私が間違っていることについて何か考えはありますか?助けてくれてありがとう!

4

0 に答える 0