グラフで行われた計算を、同じことを行うカスタム op に置き換えようとしています。
グラフに定数A
と重み変数W
があるとします。カスタム op を作成して、これら 2 つの入力を取得し、計算全体を実行します (重み更新の最後のステップを除く)。
custom_op_tensor = custom_module.custom_op([A,W])
g_def = tf.get_default_graph().as_graph_def()
input_map = { tensor.name : custom_op_tensor }
train_op, = tf.import_graph_def(g_def, input_map=input_map, return_elements=[train_op])
インポート グラフ定義の後に、2 つW
の があります。1 つは元のグラフ定義からのもので、もう 1 つはインポートされたグラフにあります。train op を実行すると、カスタム op は古いものW
を読み取り、新しいW
ものが更新されます。その結果、勾配降下法は正しいことを行うことができなくなります。
問題は、 custom_op のインスタンス化に入力の重み tensor が必要なことW
です。新しいW
ものは、インポート後にのみ認識されます。また、インポートにはカスタム操作が必要です。この問題をどのように回避しますか?