23

Tensorflow v1.3 の Dataset API で遊んでいます。それは素晴らしい。here で説明されているように、関数を使用してデータセットをマップすることができます。追加の引数を持つ関数を渡す方法を知りたいです。たとえば、次のようになりarg1ます。

def _parse_function(example_proto, arg1):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

もちろん、

dataset = dataset.map(_parse_function)

を渡す方法がないため、機能しませんarg1

4

3 に答える 3

39

ラムダ式を使用して、引数を渡したい関数をラップする例を次に示します。

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))

上記では、提供される関数のシグネチャは、mapデータセットの内容と一致する必要があります。したがって、それに一致するようにラムダ式を作成する必要があります。ここでは単純です。データセットに含まれる要素は 1 つだけで、x0 から 4 の範囲の要素が含まれます。

必要に応じて、データセットの外部から任意の数の外部引数を渡すことができます:ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)など。

上記が機能することを確認するために、マッピングが実際に各データセット要素を 2 倍することを確認できます。

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
      try:
        print(sess.run(next_x))
      except tf.errors.OutOfRangeError:
        break

出力:

0
2
4
6
8
于 2018-02-01T19:40:08.630 に答える
1

もう 1 つの解決策は、クラス ラッパーを使用することです。次のコードでは、パラメーターshapeを parse 関数に渡しています。

class MyDataSets:

    def __init__(self, shape):
        self.shape = shape

    def parse_sample(self.sample):
        features = { ... }
        f = tf.parse_example([example], features=features)

        image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
        image = image.reshape(image_raw, self.shape)

        label = tf.cast(f['label'], tf.int32)

        return image, label

    def init(self):
        ds = tf.data.TFRecordDataSets(...)
        ds = ds.map(self.parse_sample)
        ...
        return ds.make_initializable_iterator()
于 2019-03-26T00:43:34.227 に答える