2

TPUEstimator を使用してブール値のマスキング操作を実装する必要があります。tf.boolean_mask は実装されていません。回避策はありますか?

次のコードは、CPU と GPU で私の目的に完全に適合します。

  all_out = model.get_sequence_output()
  P = tf.boolean_mask(all_out, P_mask)

all_out は形状 [?, 128, 768] のテンソルです

P_mask は形状 [?, 128] であり、2 番目の次元は、抽出する目的のテンソルを表すためにワンホット エンコードされます。

P の望ましい形状は [?,768] です。

TPUEstimator を使用して TPU でこれを実行すると、次のエラー メッセージが表示されます。

Compilation failure: Detected unsupported operations when trying to
compile graph _functionalize_body_1[] on XLA_TPU_JIT: Where (No 
registered 'Where' OpKernel for XLA_TPU_JIT devices compatible with node
node boolean_mask/Where
4

1 に答える 1