TensorFlow textsum モデルを実行しようとしています
トレーニングはうまくいきましたが、モデルの作成者から提供されたおもちゃのデータを使用して「デコード」モードを実行しようとすると、次のエラーが発生します。
Traceback (most recent call last):
File "/home/pavel/Sandbox/TensorFlow/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 212, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 30, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "/home/pavel/Sandbox/TensorFlow/textsum/bazel-bin/textsum/seq2seq_attention.runfiles/__main__/textsum/seq2seq_attention.py", line 208, in main
decoder.DecodeLoop()
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_decode.py", line 101, in DecodeLoop
if not self._Decode(self._saver, sess):
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_decode.py", line 140, in _Decode
best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/beam_search.py", line 113, in BeamSearch
sess, latest_tokens, enc_top_states, states)
File "/home/pavel/Sandbox/TensorFlow/textsum/textsum/seq2seq_attention_model.py", line 283, in decode_topk
feed_dict=feed)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 864, in _run
feed_dict = nest.flatten_dict_items(feed_dict)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/nest.py", line 186, in flatten_dict_items
% (len(flat_i), len(flat_v), flat_i, flat_v))
ValueError: Could not flatten dictionary. Key had 2 elements, but value had 1 elements. Key: [<tf.Tensor 'seq2seq/encoder3/BiRNN/FW/FW/cond_119/Merge_1:0' shape=(8, 256) dtype=float32>, <tf.Tensor 'seq2seq/encoder3/BiRNN/FW/FW/cond_119/Merge_2:0' shape=(8, 256) dtype=float32>], value: [array([[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
...,
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]],
[[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
...,
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227],
[ 2.72009921, -3.38784456, -0.47797373, ..., 0.81595671,
-6.4550662 , 5.68320227]]], dtype=float32)].
デコードを行うために実行するコマンド:
bazel-bin/textsum/seq2seq_attention --mode=decode --article_key=article --abstract_key=abstract --data_path=data/predict --vocab_path=data/vocab --log_root=log_root --decode_dir=log_root/decode --beam_size=8 --truncate_input=True
その原因は何ですか?
CUDA7.5
CUDNN 5.1
テンソルフロー 0.10
更新: GitHub の問題に関するコメントに基づいて、TensorFlow の以前のバージョン: 0.9 をインストールしようとしました: https://github.com/tensorflow/models/issues/417 問題の解決に役立ちました。バージョン 0.10 で動作しない理由はまだわかりません。