0

画像生成内で DCGAN + Reptile を使用するメタ学習クラスを分析しています。

このコードについて 2 つの質問があります。

最初の質問: DCGAN トレーニング中の理由 (74 行目)

training_batch = torch.cat ([real_batch, fake_batch])

実例 (real_batch) と偽例 (fake_batch) で構成される training_batch は作成されますか? 実像と偽像を混ぜて訓練を行うのはなぜですか? 多くの DCGAN を見てきましたが、この方法でトレーニングを行ったことはありません。

2 番目の質問: トレーニング中に normalize_data 関数 (49 行目) と unnormalize_data 関数 (55 行目) が使用されるのはなぜですか?

def normalize_data(data):
    data *= 2
    data -= 1
    return data


def unnormalize_data(data):
    data += 1
    data /= 2
    return data

このプロジェクトでは Mnist データセットを使用しています。CIFAR10 のようなカラー データセットを使用したい場合、それらの正規化を変更する必要がありますか?

4

2 に答える 2

1

GAN のトレーニングには、弁別器に実際の例と偽の例を与えることが含まれます。通常、それらは 2 回に分けて与えられます。デフォルトでは、バッチ次元であるtorch.cat最初の次元 ( ) でテンソルを連結します。dim=0したがって、バッチ サイズが 2 倍になり、前半が実際の画像で、後半が偽の画像になります。

損失を計算するために、前半 (元のバッチ サイズ) が本物として分類され、後半が偽物として分類されるように、ターゲットを調整します。からinitialize_gan:

self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)

画像は [0, 1] の間の float 値で表されます。正規化は、[-1, 1] の間の値を生成するように変更します。GAN は通常、ジェネレーターで tanh を使用するため、偽の画像は [-1, 1] の間の値を持つため、実際の画像は同じ範囲内にある必要があります。 .

これらの画像を表示したい場合は、最初に非正規化する必要があります。つまり、[0, 1] の間の値に変換します。

このプロジェクトでは Mnist データセットを使用しています。CIFAR10 のようなカラー データセットを使用したい場合、それらの正規化を変更する必要がありますか?

いいえ、それらを変更する必要はありません。カラーの画像も [0, 1] の間の値を持ち、3 つのチャンネル (RGB) を表す値が増えるだけです。

于 2020-05-21T17:36:05.127 に答える