私はGANを構築していますが、再利用を使用してディスクリミネーターを2回呼び出し始めたとき、GANが発散し始めました。最初に次のように識別子を作成しました。
def discriminator(self, x_past, x_future, gen_future):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
with tf.variable_scope("disc") as disc:
gen_future = tf.concat([gen_future, x_past], 2)
x_future = tf.concat([x_future, x_past], 2)
x_in = tf.concat([gen_future, x_future], 0)
conv1 = tf.layers.conv1d(inputs=x_in, filters=20, kernel_size=3, strides=1,
padding='same', activation=tf.nn.relu)
max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')
conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=3, kernel_size=2, strides=1,
padding='same', activation=tf.nn.relu)
max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')
# Flatten and add dropout
flat = tf.reshape(max_pool_2, (-1, 9))
flat = tf.nn.dropout(flat, keep_prob=self.keep_prob)
# Predictions
logits = tf.layers.dense(flat, 2)
y_true = logits[:self.batch_size]
y_gen = logits[self.batch_size:]
return y_true, y_gen
そして、私はそれを次のように呼んでいました:
y_true, y_gen = self.discriminator(x_past, x_future, gen_future)
GANを適切にトレーニングできました。ここで、実際のデータと偽のデータを一度に送信しなくても呼び出すことができるように、reuse を使用する必要があります。私はそれを次のように変更しました:
def discriminator(self, x_past, x_future, reuse=False):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
with tf.variable_scope("disc", reuse=reuse) as disc:
x_in = tf.concat([x_future, x_past], 2)
conv1 = tf.layers.conv1d(inputs=x_in, filters=20, kernel_size=3, strides=1,
padding='same', activation=tf.nn.relu)
max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')
conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=3, kernel_size=2, strides=1,
padding='same', activation=tf.nn.relu)
max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')
# Flatten and add dropout
flat = tf.reshape(max_pool_2, (-1, 9))
flat = tf.nn.dropout(flat, keep_prob=self.keep_prob)
# Predictions
logits = tf.layers.dense(flat, 2)
return logits
そして、次のように呼び出します。
y_true = self.discriminator(x_past, x_future)
y_gen = self.discriminator(x_past, gen_future, reuse=True)
今、それは発散し始めました。それはなぜですか?