1

tensorflow を使用して GAN をトレーニングし、ジェネレーターとディスクリミネーターを tensorflow_hub モジュールとしてエクスポートしたいと考えています。
そのために:
- GAN アーキテクチャを tensorflow で定義します
- トレーニングしてチェックポイントを保存します-
次のようなさまざまなタグで module_spec を作成します:
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- トレーニング中に保存した checkpoint_path を使用して、tf_hub_path で module_spec でエクスポートします

次に、次のコマンドでジェネレーターをロードできます。

hub.Module(tf_hub_path, tags={"gen", "bs8"})

しかし、同様のコマンドを使用してディスクリミネーターをロードしようとすると:

hub.Module(tf_hub_path, tags={"disc", "bs8"})

エラーが発生しました:

ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}

したがって、ディスクリミネーターに存在する変数は、ディスク上のモジュールに保存されていないと結論付けました。私が想像したさまざまなエラーの原因を確認しました。

  • モジュール仕様が正しく定義されていること。そのために、モデルをトレーニングし、モジュール スペックを作成し、その module_spec からモジュールを直接ロードすることにしました。これは、ジェネレーターとディスクリミネーターでうまく機能しました。次に、私のmodule_specが正しいと仮定しました
  • 次に、チェックポイントがすべての変数をグラフに正しく保存しているかどうか疑問に思っていました。

    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    inspect_list = tf.train.list_variables(checkpoint_path)
    print(inspect_list)
    
    [('disc_step_1/beta1_power', []),
    ('disc_step_1/beta2_power', []),
    ('discriminator/linear/bias', [1]),
    ('discriminator/linear/bias/d_opt', [1]),
    ('discriminator/linear/bias/d_opt_1', [1]),
    ('discriminator/linear/kernel', [3, 1]),
    ('discriminator/linear/kernel/d_opt', [3, 1]),
    ('discriminator/linear/kernel/d_opt_1', [3, 1]),
    ('gen_step/beta1_power', []),
    ('gen_step/beta2_power', []),
    ('generator/fc_noise/bias', [48]),
    ('generator/fc_noise/bias/g_opt', [48]),
    ('generator/fc_noise/bias/g_opt_1', [48]),
    ('generator/fc_noise/kernel', [2, 48]),
    ('generator/fc_noise/kernel/g_opt', [2, 48]),
    ('generator/fc_noise/kernel/g_opt_1', [2, 48]),
    ('global_step', []),
    ('global_step_disc', [])]
    

    したがって、すべての変数がチェックポイント内に正しく保存されていることがわかりました。ジェネレーターに関連する 2 つの変数のみが、ディスク上の tf ハブ モジュールに正しくエクスポートされました。

最後に、私のエラーは次のものから来ていると思います:

module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)

checkpoint_path から変数をエクスポートするために、タグ「gen」のみが考慮されます。また、変数の名前が module.variable_map とチェックポイント パスのリスト変数の間で一致していることも確認しました。タグ「disc」を持つモジュールの変数マップは次のとおりです。

print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}

私は持っている

  • テンソルフロー: 1.13.1
  • tensorflow_hub: 0​​.4.0
  • パイソン: 3.5.2

ご協力いただきありがとうございます

4

1 に答える 1

1

これを行う最もクリーンな方法ではないと思いますが、この問題を処理する方法を見つけました。

次のコード行は、タグなしで hub.Module を呼び出すときに、デフォルトでモジュールを定義します。

(set(), {'batch_size': 8, 'model': 'gen'})

実際、この一連のパラメーターが、module_spec.export を通じてエクスポートされるグラフを定義していることに気付きました。モジュールをインポートするときにジェネレーターの変数にアクセスできたのに、ディスクリミネーターの変数にはアクセスできなかった理由を説明しています。
したがって、デフォルトでこの一連のパラメーターを使用することにしました。

(set(), {'batch_size': 8, 'model': 'both'})

そして、hub.create_module_spec によって呼び出される _module_fn メソッドで、ジェネレーターとディスクリミネーターの両方の入力 (およびそれぞれの出力) をモデルの入力 (それぞれの出力) として定義しました。したがって、module_spec をエクスポートすると、グラフのすべての変数にアクセスできます。

于 2019-06-26T09:55:45.480 に答える