import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Conv2D, Flatten, Reshape, Input, Embedding
from keras.utils.np_utils import to_categorical
BATCH_SIZE = 64
ITERATIONS = 50000

# Only the training split is used; the test split is discarded.
(xtrain, ytrain), (_, _) = mnist.load_data()

# strain = [num_samples, height, width, 1]. Later code reads strain[0] as the
# dataset size and strain[1], strain[2] as the image height and width.
strain = list(xtrain.shape)
strain.append(1)

# Scale pixels to [0, 1] so real images share the range of the generator's
# sigmoid output. With a [-0.5, 0.5] range the discriminator could separate
# real from fake images by intensity alone. float32 halves the memory of the
# default float64 result of the division.
xtrain = xtrain.reshape(strain).astype('float32') / 255

# 11 classes: digits 0-9 plus class 10, which the training step uses to label
# generated ("fake") images.
ytrain = to_categorical(ytrain, num_classes=11)

# Dimensionality of the generator's latent noise vector.
z_dim = 100
def build_generator(num_classes=10):
    """Build the conditional generator: (noise, class id) -> image.

    Args:
        num_classes: Number of distinct class ids the generator is
            conditioned on. Ids fed to the model must lie in
            ``[0, num_classes)``. Defaults to the ten MNIST digits.

    Returns:
        A Keras ``Model`` taking ``(z, cls)``, where ``z`` has shape
        ``(batch, z_dim)`` and ``cls`` has shape ``(batch, 1)``. It produces
        images of shape ``(batch, strain[1], strain[2], 1)`` with values in
        ``(0, 1)`` (sigmoid output).
    """
    n_pixels = strain[1] * strain[2]

    noise_input = Input(shape=(z_dim,))
    class_input = Input(shape=(1,))

    noise_features = Dense(n_pixels, activation='relu')(noise_input)

    # input_dim must cover every class id that is looked up. The training
    # step and draw_sample feed ids 0-9, which an Embedding with input_dim=1
    # cannot represent.
    class_features = Embedding(num_classes, n_pixels)(class_input)
    # (batch, 1, n_pixels) -> (batch, n_pixels) so it can be summed with the
    # noise features.
    class_features = Flatten()(class_features)

    x = noise_features + class_features
    x = Dense(n_pixels, activation='sigmoid')(x)
    image = Reshape((strain[1], strain[2], 1))(x)
    return Model((noise_input, class_input), image)
def build_discriminator():
    """Build the CNN classifier that serves as the GAN discriminator.

    Maps an image of shape ``strain[1:]`` to a softmax over 11 classes:
    the ten digits plus one extra class used for generated images.
    """
    stack = [
        # Two stride-2 convolutions downsample the image by 4x overall.
        Conv2D(8, kernel_size=3, strides=2, padding='same', activation='relu',
               input_shape=strain[1:]),
        Conv2D(16, kernel_size=3, strides=2, padding='same', activation='relu'),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(11, activation='softmax'),
    ]
    return Sequential(stack)
def build_gan(_generator, _discriminator):
    """Chain generator and discriminator into a single model.

    The combined model takes the same ``(z, cls)`` inputs as the generator.
    It returns the discriminator's 11-way class distribution for the image
    the generator produced.
    """
    noise = Input(shape=(z_dim,))
    label = Input(shape=(1,))
    fake_image = _generator((noise, label))
    return Model((noise, label), _discriminator(fake_image))
# Compile the discriminator while it is still trainable, so that its own
# train_on_batch calls (in train_on_batch below) update its weights.
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
generator = build_generator()
# Freeze the discriminator before wrapping it into the combined model, so that
# training `gan` only updates the generator's weights.
# NOTE(review): this ordering relies on Keras keeping the trainable state that
# was in effect at each model's compile() call; confirm for the installed
# Keras version.
discriminator.trainable = False
# Combined model: (z, cls) -> generator -> discriminator -> 11-way softmax.
gan = build_gan(generator, discriminator)
gan.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
def train_on_batch():
    """Run one discriminator update followed by one generator update.

    The discriminator is trained on BATCH_SIZE real images labelled with
    their digit, plus BATCH_SIZE generated images labelled as class 10.
    The generator is then trained through the combined ``gan`` model. Its
    target is that the discriminator assigns each generated image the digit
    class it was conditioned on.

    Returns:
        ``(d_loss, g_loss)``: the discriminator loss and the combined-model
        (generator) loss for this step.
    """
    # Real batch, drawn with replacement from the training set.
    idx = np.random.randint(0, strain[0], (BATCH_SIZE,))
    img = xtrain[idx]
    label = ytrain[idx]

    # Fake batch conditioned on random digit classes 0-9.
    z = np.random.normal(0, 1, (BATCH_SIZE, z_dim))
    z_cls = np.random.randint(0, 10, (BATCH_SIZE,))
    # Shape (BATCH_SIZE, 1) to match the models' Input(shape=(1,)).
    z_cls_input = z_cls.reshape(-1, 1)
    # predict_on_batch runs one forward pass on exactly this batch. It avoids
    # predict()'s per-call dataset/batching setup, which adds up when called
    # every iteration.
    gan_img = generator.predict_on_batch((z, z_cls_input))

    # For the discriminator, generated images are labelled "fake" (class 10) ...
    gan_label = to_categorical(np.full((BATCH_SIZE,), 10), num_classes=11)
    # ... while the generator's target is the digit it was conditioned on.
    gan_label_cls = to_categorical(z_cls, num_classes=11)

    x = np.concatenate([img, gan_img])
    y = np.concatenate([label, gan_label])
    d_loss, _ = discriminator.train_on_batch(x=x, y=y)
    g_loss, _ = gan.train_on_batch(x=(z, z_cls_input), y=gan_label_cls)
    return d_loss, g_loss
def draw_sample():
    """Generate and show one image per digit class 0-9 in a 2x5 grid.

    Fresh Gaussian noise is drawn on every call, so the samples differ
    between calls. Blocks on ``plt.show()`` until the window is closed
    (backend dependent).
    """
    _, axes = plt.subplots(nrows=2, ncols=5)
    # One class id per subplot, shaped (10, 1) for the generator's class input.
    cls = np.arange(10).reshape(-1, 1)
    # Use z_dim (not a literal) so the noise always matches the generator input.
    z = np.random.normal(0, 1, (len(cls), z_dim))
    img = generator.predict_on_batch((z, cls))
    # axes.flat walks the grid row by row, so subplot (i, j) shows class 5*i + j.
    for digit, (ax, image) in enumerate(zip(axes.flat, img)):
        # Drop the trailing channel axis; a grayscale imshow takes (H, W).
        ax.imshow(image[..., 0], cmap='gray')
        ax.set_title(str(digit))
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
def draw_learning_curve(_losses, _iterations=None):
    """Plot the discriminator and generator losses recorded during training.

    Args:
        _losses: Sequence of ``(d_loss, g_loss)`` pairs, one per recorded
            checkpoint (the order returned by ``train_on_batch``).
        _iterations: Optional x-axis values of the same length as
            ``_losses``, e.g. the iteration number of each checkpoint.
            When omitted, the checkpoint index 0, 1, 2, ... is used.
    """
    # reshape keeps an (n, 2) layout even when no checkpoint was recorded,
    # so the column indexing below never fails on an empty list.
    losses = np.asarray(_losses, dtype=float).reshape(-1, 2)
    if _iterations is None:
        x = np.arange(len(losses))
        x_label = 'checkpoint'
    else:
        x = np.asarray(_iterations)
        x_label = 'iteration'
    # Column 0 holds the discriminator loss, column 1 the generator loss.
    plt.plot(x, losses[:, 0], 'r', label='discriminator')
    plt.plot(x, losses[:, 1], 'b', label='generator')
    plt.xlabel(x_label)
    plt.ylabel('loss')
    plt.legend()
    plt.show()
# Loss pairs (d_loss, g_loss) and the 1-based iteration at which each was taken.
losses = []
checkpoints = []
for iteration in range(ITERATIONS):
    # 1-based iteration number, used consistently for logging and checkpoints.
    step = iteration + 1
    print("-- Iteration: %d --" % step)
    loss = train_on_batch()
    if step % 1000 == 0:
        losses.append(loss)
        checkpoints.append(step)
        # Report the same number that was just stored in `checkpoints`.
        print("Iteration: %d, D Loss: %.4f, G Loss: %.4f" % (step, loss[0], loss[1]))
        draw_sample()
draw_learning_curve(losses)