tensorflow を使用して GAN をトレーニングし、ジェネレーターとディスクリミネーターを tensorflow_hub モジュールとしてエクスポートしたいと考えています。
そのために:
- GAN アーキテクチャを tensorflow で定義します
- トレーニングしてチェックポイントを保存します-
次のようなさまざまなタグで module_spec を作成します:
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- トレーニング中に保存した checkpoint_path を使用して、tf_hub_path で module_spec でエクスポートします
次に、次のコマンドでジェネレーターをロードできます。
hub.Module(tf_hub_path, tags={"gen", "bs8"})
しかし、同様のコマンドを使用してディスクリミネーターをロードしようとすると:
hub.Module(tf_hub_path, tags={"disc", "bs8"})
エラーが発生しました:
ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}
したがって、ディスクリミネーターに存在する変数は、ディスク上のモジュールに保存されていないと結論付けました。私が想像したさまざまなエラーの原因を確認しました。
- モジュール仕様が正しく定義されていること。そのために、モデルをトレーニングし、モジュール スペックを作成し、その module_spec からモジュールを直接ロードすることにしました。これは、ジェネレーターとディスクリミネーターでうまく機能しました。次に、私のmodule_specが正しいと仮定しました
次に、チェックポイントがすべての変数をグラフに正しく保存しているかどうか疑問に思っていました。
checkpoint_path = tf.train.latest_checkpoint(self.model_dir) inspect_list = tf.train.list_variables(checkpoint_path) print(inspect_list) [('disc_step_1/beta1_power', []), ('disc_step_1/beta2_power', []), ('discriminator/linear/bias', [1]), ('discriminator/linear/bias/d_opt', [1]), ('discriminator/linear/bias/d_opt_1', [1]), ('discriminator/linear/kernel', [3, 1]), ('discriminator/linear/kernel/d_opt', [3, 1]), ('discriminator/linear/kernel/d_opt_1', [3, 1]), ('gen_step/beta1_power', []), ('gen_step/beta2_power', []), ('generator/fc_noise/bias', [48]), ('generator/fc_noise/bias/g_opt', [48]), ('generator/fc_noise/bias/g_opt_1', [48]), ('generator/fc_noise/kernel', [2, 48]), ('generator/fc_noise/kernel/g_opt', [2, 48]), ('generator/fc_noise/kernel/g_opt_1', [2, 48]), ('global_step', []), ('global_step_disc', [])]
したがって、すべての変数がチェックポイント内に正しく保存されていることがわかりました。ジェネレーターに関連する 2 つの変数のみが、ディスク上の tf ハブ モジュールに正しくエクスポートされました。
最後に、私のエラーは次のものから来ていると思います:
module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)
checkpoint_path から変数をエクスポートするために、タグ「gen」のみが考慮されます。また、変数の名前が module.variable_map とチェックポイント パスのリスト変数の間で一致していることも確認しました。タグ「disc」を持つモジュールの変数マップは次のとおりです。
print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}
私は持っている
- テンソルフロー: 1.13.1
- tensorflow_hub: 0.4.0
- パイソン: 3.5.2
ご協力いただきありがとうございます