4

現在、埋め込みを使用したいワンホット エンコーディングがあります。しかし、私が電話するとき

embed=tf.nn.embedding_lookup(embeddings, train_data) 
print(embed.get_shape())

埋め込みデータ形状 (11、32、729、128)

この形状は (11, 32, 128) である必要がありますが、train_data がワンホット エンコードされているため、間違った次元が得られます。

train_data2=tf.matmul(train_data,tf.range(729))

エラーを教えてください:

ValueError: Shape must be rank 2 but is rank 3

助けてください!ありがとう。

4

1 に答える 1

2

あなたの例の小さな修正:

encoding_size = 4
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
one_hot_indexes = tf.matmul(one_hot_batch, np.array([range(encoding_size)], 
    dtype=np.int32).T)

with tf.Session() as session:
  print one_hot_indexes.eval()

別の方法:

batch_size = 3
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
one_hot_indexes = tf.where(tf.not_equal(one_hot_batch, 0))
one_hot_indexes = one_hot_indexes[:, 1]
one_hot_indexes = tf.reshape(one_hot_indexes, [batch_size, 1])
with tf.Session() as session:
  print one_hot_indexes.eval()

両方の場合の結果:

[[3]
 [1]
 [0]]
于 2016-11-09T06:25:44.690 に答える