I am trying to train a DCGAN using tensorflow.keras (version 2.1).
I followed the official tutorial (https://www.tensorflow.org/tutorials/generative/dcgan), and the official code trains fine.
However, when I rewrote it as shown below, training fails: the generated results look like noise, and the losses stay at almost the same values no matter how many training iterations I run.
I cannot figure out what is causing this...
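For reference, as far as I can tell the tutorial linked above defines both losses on the discriminator's outputs, roughly like this (cross_entropy there is the same BinaryCrossentropy(from_logits=True) that I use as loss_fn below):

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    # fake_output is the discriminator's logits on generated images, D(G(z))
    return cross_entropy(tf.ones_like(fake_output), fake_output)

And here is my rewritten version, which fails: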
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
import argparse
import cv2
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
from tensorflow.keras import layers
from tensorflow.keras.layers import *
from tensorflow.keras.initializers import RandomNormal as RN, Constant
import pickle
import os
# config
class_N = 2
img_height, img_width = 32, 32
channel = 3
# GAN config
Z_dim = 100
# model path
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
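# note: with img_height = img_width = 32, the generator below starts from a
# 2x2 feature map (32 / 16 = 2) and doubles the resolution four times
# (2 -> 4 -> 8 -> 16 -> 32) to get back to 32x32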
def Generator():
    inputs = Input((Z_dim,))
    in_h = int(img_height / 16)
    in_w = int(img_width / 16)
    base = 128
    # 1/16
    x = Dense(in_h * in_w * base, name='g_dense1',
              kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(inputs)
    x = Reshape((in_h, in_w, base), input_shape=(base * in_h * in_w,))(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_dense1_bn')(x)
    # 1/8
    x = Conv2DTranspose(base * 4, (5, 5), name='g_conv1', padding='same', strides=(2, 2),
                        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv1_bn')(x)
    # 1/4
    x = Conv2DTranspose(base * 2, (5, 5), name='g_conv2', padding='same', strides=(2, 2),
                        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv2_bn')(x)
    # 1/2
    x = Conv2DTranspose(base, (5, 5), name='g_conv3', padding='same', strides=(2, 2),
                        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv3_bn')(x)
    # 1/1
    x = Conv2DTranspose(channel, (5, 5), name='g_out', padding='same', strides=(2, 2),
                        kernel_initializer=RN(mean=0.0, stddev=0.02), bias_initializer=Constant())(x)
    x = Activation('tanh')(x)
    model = tf.keras.Model(inputs=inputs, outputs=x, name='G')
    return model
def Discriminator():
    base = 32
    inputs = Input((img_height, img_width, channel))
    x = Conv2D(base, (5, 5), padding='same', strides=(2, 2), name='d_conv1',
               kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base * 2, (5, 5), padding='same', strides=(2, 2), name='d_conv2',
               kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base * 4, (5, 5), padding='same', strides=(2, 2), name='d_conv3',
               kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base * 8, (5, 5), padding='same', strides=(2, 2), name='d_conv4',
               kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    x = Dense(1, name='d_out',
              kernel_initializer=RN(mean=0.0, stddev=0.02), bias_initializer=Constant())(x)
    model = tf.keras.Model(inputs=inputs, outputs=x, name='D')
    return model
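# quick shape check (a throwaway sanity-check sketch I added; the leading
# underscore names are mine, not part of the training code): G should map
# (batch, Z_dim) noise to (batch, 32, 32, 3) tanh images, and D should map
# those images to (batch, 1) raw logits, matching from_logits=True below.
_G = Generator()
_D = Discriminator()
_z = tf.random.normal([4, Z_dim])
print(_G(_z).shape, _D(_G(_z)).shape)  # expect (4, 32, 32, 3) (4, 1)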
def load_cifar10():
    path = 'drive/My Drive/Colab Notebooks/' + 'cifar-10-batches-py'
    if not os.path.exists(path):
        os.system("wget {}".format(path))
        os.system("tar xvf {}".format(path))
    # train data
    train_x = np.ndarray([0, 32, 32, 3], dtype=np.float32)
    train_y = np.ndarray([0, ], dtype=np.int)
    for i in range(1, 6):
        data_path = path + '/data_batch_{}'.format(i)
        with open(data_path, 'rb') as f:
            datas = pickle.load(f, encoding='bytes')
        print(data_path)
        x = datas[b'data']
        x = x.reshape(x.shape[0], 3, 32, 32)
        x = x.transpose(0, 2, 3, 1)
        train_x = np.vstack((train_x, x))
        y = np.array(datas[b'labels'], dtype=np.int)
        train_y = np.hstack((train_y, y))
    # test data
    data_path = path + '/test_batch'
    with open(data_path, 'rb') as f:
        datas = pickle.load(f, encoding='bytes')
    print(data_path)
    x = datas[b'data']
    x = x.reshape(x.shape[0], 3, 32, 32)
    test_x = x.transpose(0, 2, 3, 1)
    test_y = np.array(datas[b'labels'], dtype=np.int)
    return train_x, train_y, test_x, test_y
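# quick data check (another throwaway sketch, assuming the standard CIFAR-10
# python batches): train_x should be (50000, 32, 32, 3) float32 in [0, 255];
# train() below rescales it to [-1, 1] to match the generator's tanh output.
_tx, _ty, _, _ = load_cifar10()
print(_tx.shape, _tx.min(), _tx.max())  # expect (50000, 32, 32, 3) 0.0 255.0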
# train
def train():
    # model
    G = Generator()
    D = Discriminator()
    train_x, train_y, test_x, test_y = load_cifar10()
    xs = train_x / 127.5 - 1
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # training
    mb = 64
    mbi = 0
    train_N = len(xs)
    train_ind = np.arange(train_N)
    np.random.seed(0)

    @tf.function
    def train_iter(x, z):
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            # feed forward
            # z -> G -> Gz
            Gz = G(z, training=True)
            # x -> D -> Dx
            # z -> G -> Gz -> D -> DGz
            Dx = D(x, training=True)
            DGz = D(Gz, training=True)
            # get loss
            loss_G = loss_fn(tf.ones_like(Gz), Gz)
            loss_D_real = loss_fn(tf.ones_like(Dx), Dx)
            loss_D_fake = loss_fn(tf.zeros_like(DGz), DGz)
            loss_D = loss_D_real + loss_D_fake
        # feed back
        gradients_of_G = G_tape.gradient(loss_G, G.trainable_variables)
        gradients_of_D = D_tape.gradient(loss_D, D.trainable_variables)
        # update parameter
        G_optimizer.apply_gradients(zip(gradients_of_G, G.trainable_variables))
        D_optimizer.apply_gradients(zip(gradients_of_D, D.trainable_variables))
        return loss_G, loss_D

    #with strategy.scope():
    # optimizer
    G_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
    D_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
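    # note: beta_1=0.5 follows the DCGAN paper; the tutorial itself just uses
    # Adam(1e-4) with default betas. I doubt this difference alone breaks training.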
    checkpoint = tf.train.Checkpoint(G_optimizer=G_optimizer, D_optimizer=D_optimizer, G=G, D=D)

    for ite in range(10000):
        if mbi + mb > train_N:
            mb_ind = train_ind[mbi:]
            np.random.shuffle(train_ind)
            mb_ind = np.hstack((mb_ind, train_ind[:(mb - (train_N - mbi))]))
            mbi = mb - (train_N - mbi)
        else:
            mb_ind = train_ind[mbi: mbi + mb]
            mbi += mb
        x = xs[mb_ind]
        z = np.random.uniform(-1, 1, size=(mb, Z_dim))
        #z = tf.random.normal([mb, Z_dim])
        loss_G, loss_D = train_iter(x, z)
        if (ite + 1) % 100 == 0:
            print("iter >>", ite + 1, ',G:loss >>', loss_G.numpy(), ',D:loss >>', loss_D.numpy())
        # display generated image
        if (ite + 1) % 1000 == 0:
            Gz = G(z)
            _Gz = (Gz * 127.5 + 127.5).numpy().astype(int)
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                plt.imshow(_Gz[i])
                plt.axis('off')
            plt.show()
    # save model
    checkpoint.save(file_prefix=checkpoint_prefix)
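# entry point: I call train() at the bottom of the same Colab cell
# (the call was omitted above; without it nothing runs)
train()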