1

テンソル フロー チュートリアルの 1 つで例を実行するのに問題があります。チュートリアルでは、入力するだけで実行できると書かれていますpython fully_connected_feed.py。これを行うと、入力データを取得できますが、次のように失敗します。

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
Traceback (most recent call last):
  File "fully_connected_feed.py", line 225, in <module>
    tf.app.run()
  File "/Users/me/anaconda/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
    sys.exit(main(sys.argv))
  File "fully_connected_feed.py", line 221, in main
    run_training()
  File "fully_connected_feed.py", line 141, in run_training
    loss = mnist.loss(logits, labels_placeholder)
  File "/Users/me/tftmp/mnist.py", line 96, in loss
    indices = tf.expand_dims(tf.range(batch_size), 1)
TypeError: range() takes at least 2 arguments (1 given)

セッションのセットアップやテンソルの評価に何らかの問題があるため、このエラーが発生したと思います。これは、問題を引き起こしている mnist.py の関数です。

def loss(logits, labels):
  """Calculates the loss from the logits and the labels.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size].

  Returns:
    loss: Loss tensor of type float.
  """
  # Convert from sparse integer labels in the range [0, NUM_CLASSSES)
  # to 1-hot dense float vectors (that is we will have batch_size vectors,
  # each with NUM_CLASSES values, all of which are 0.0 except there will
  # be a 1.0 in the entry corresponding to the label).
  batch_size = tf.size(labels)
  labels = tf.expand_dims(labels, 1)
  indices = tf.expand_dims(tf.range(batch_size), 1)
  concated = tf.concat(1, [indices, labels])
  onehot_labels = tf.sparse_to_dense(
      concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
                                                          name='xentropy')
  loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
  return loss

lossブロック内の関数にすべてのコードを配置すると、with tf.Session():このエラーが発生します。ただし、後で初期化されていない変数について他のエラーが発生するため、セッションのセットアップまたは初期化などで何か大きな問題が発生していると推測しています。テンソルフローが初めてなので、少し途方に暮れています。何か案は?

[注:コードをまったく編集していません。テンソルフローチュートリアルからダウンロードして、指示どおりに実行しようとしましたpython fully_connected_feed.py]

4

3 に答える 3

6

この問題が発生するのは、GitHub の TensorFlow ソースの最新バージョンでは、tf.range()がその引数をより許容するように更新されrange()たためです (以前は 2 つの引数が必要でしたが、現在は Python の組み込み関数と同じセマンティクスを持っていfully_connected_feed.pyます)。これを利用するために更新されました。

ただし、TensorFlow のバイナリ ディストリビューションに対してこのバージョンを実行しようとすると、への変更がtf.range()バイナリ パッケージに組み込まれていないため、このエラーが発生します。

最も簡単な解決策は、古いバージョンのmnist.pyをダウンロードすることです。または、ソースからビルドして最新バージョンのチュートリアルを使用することもできます。

于 2015-11-13T16:21:23.983 に答える