35

TFRecords からすべての例を一度に読み取るにはどうすればよいですか?

私は、 fully_connected_readerの例のメソッドでtf.parse_single_example与えられたものと同様のコードを使用して、個々の例を読み取るために使用してきました。ただし、検証データセット全体に対して一度にネットワークを実行したいので、代わりにそれら全体をロードしたいと考えています。read_and_decode

完全にはわかりませんが、ドキュメントでは、TFRecords ファイル全体を一度にロードするtf.parse_example代わりに使用できることが示唆されているようです。tf.parse_single_example私はこれを機能させることができないようです。機能の指定方法に関係していると推測していますが、機能仕様で複数の例があることをどのように述べているかわかりません。

言い換えれば、次のようなものを使用しようとする私の試み:

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

機能していません。機能が一度に複数の例を期待していないためだと思います(ただし、よくわかりません)。[これにより、エラーが発生しますValueError: Shape () must have rank 1]

これは一度にすべてのレコードを読み取る適切な方法ですか? その場合、実際にレコードを読み取るには何を変更する必要がありますか? どうもありがとう!

4

7 に答える 7

24

わかりやすくするために、1 つの .tfrecords ファイルに数千の画像があります。これらは 720 x 720 の rgb png ファイルです。ラベルは 0、1、2、3 のいずれかです。

parse_example も使用してみましたが、うまくいきませんでしたが、このソリューションは parse_single_example で動作します。

欠点は、現在、各 .tf レコードにいくつのアイテムがあるかを知る必要があることです。これはちょっと残念です。より良い方法が見つかったら、回答を更新します。また、.tfrecords ファイル内のレコード数の範囲外に注意してください。最後のレコードを超えてループすると、最初のレコードからやり直されます。

秘訣は、キュー ランナーにコーディネーターを使用させることでした。

画像が正しいことを確認できるように、画像が読み込まれるときに画像を保存するためのコードをここに残しました。

from PIL import Image
import numpy as np
import tensorflow as tf

def read_and_decode(filename_queue):
 reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue)
 features = tf.parse_single_example(
  serialized_example,
  # Defaults are not specified since both keys are required.
  features={
      'image_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
      'height': tf.FixedLenFeature([], tf.int64),
      'width': tf.FixedLenFeature([], tf.int64),
      'depth': tf.FixedLenFeature([], tf.int64)
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 label = tf.cast(features['label'], tf.int32)
 height = tf.cast(features['height'], tf.int32)
 width = tf.cast(features['width'], tf.int32)
 depth = tf.cast(features['depth'], tf.int32)
 return image, label, height, width, depth


def get_all_records(FILE):
 with tf.Session() as sess:
   filename_queue = tf.train.string_input_producer([ FILE ])
   image, label, height, width, depth = read_and_decode(filename_queue)
   image = tf.reshape(image, tf.pack([height, width, 3]))
   image.set_shape([720,720,3])
   init_op = tf.initialize_all_variables()
   sess.run(init_op)
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)
   for i in range(2053):
     example, l = sess.run([image, label])
     img = Image.fromarray(example, 'RGB')
     img.save( "output/" + str(i) + '-train.png')

     print (example,l)
   coord.request_stop()
   coord.join(threads)

get_all_records('/path/to/train-0.tfrecords')
于 2016-05-15T18:01:36.463 に答える