私は CTGAN コードを書いていて、分散された方法でトレーニングしたいと考えています。したがって、私はtf.distribute.Strategy.mirroredstrategy()を使用しています。私がフォロー しているtensorflow docs チュートリアルでは、distribute_trainstep() という関数から train_step コードを呼び出し、それを tf.function で装飾する必要があると述べられています。 . そのようです:
@tf.function
def distributed_train_step(dataset_inputs):
per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
これは簡単ですが、tf.function で train_step 内をすべてデコレートすると、train_step 内のすべての numpy コードが役に立たなくなります。私は何をすべきか?train_step 内の関数を選択的にラップするだけの代替手段はありますか? それとも、すべてのnumpy操作をtensorflowのものに置き換える必要がありますか?