4

Tensorflow で Generative Adversarial Network を実装しました。テスト時に生成された画像は、トレーニング中に使用したのと同じ batch_size (64) を使用して生成した場合、非常に良好です。一度に 1 つの画像を生成すると、結果は恐ろしいものになります。

考えられる原因は 2 です。

  • バッチ正規化?
  • 動的バッチサイズを取得するための tf.shape の間違った使用法

これが私のコードです:

from tensorflow.contrib.layers.python.layers import batch_norm

def conc(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()

    x0 = tf.shape(x)[0]
    x1 = x_shapes[1].value
    x2 = x_shapes[2].value
    y3 = y_shapes[3].value

    return tf.concat([x, y * tf.ones(shape=(x0,x1,x2,y3))], 3)

def batch_normal(input, scope="scope", reuse=False):
    return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope, reuse=reuse, updates_collections=None)

def generator(z_var, y):

     y_dim = y.get_shape()[1].value

     z_var = tf.concat([z_var, y], 1)

     d1 = tf.layers.dense(z_var, 1024,
                     kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                     name='gen_fc1')

    d1 = tf.nn.relu(batch_normal(d1, scope='gen_bn1'))

    # add the second layer

    d1 = tf.concat([d1, y], 1)

    d2 = tf.layers.dense(d1, 7 * 7 * 128,
                     kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                     name='gen_fc2')

    d2 = tf.nn.relu(batch_normal(d2, scope='gen_bn2'))

    d2 = tf.reshape(d2, [-1, 7, 7, 128])
    y = tf.reshape(y, shape=[-1, 1, 1, y_dim])

    d2 = conc(d2, y)

    deconv1 = tf.layers.conv2d_transpose(d2, 64, (4, 4), strides=(2, 2), padding='same',
                                     kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                                     name='gen_deconv1')

    d3 = tf.nn.relu(batch_normal(deconv1, scope='gen_bn3'))

    d3 = conc(d3, y)

    deconv2 = tf.layers.conv2d_transpose(d3, 1, (4, 4), strides=(2, 2), padding='same',
                                     kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                                     name='gen_deconv2')

    return tf.nn.sigmoid(deconv2)
4

2 に答える 2