栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

[GAN实战] DCGAN实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

[GAN实战] DCGAN实现

'''
>>> import tensorflow as tf;tf.__version__
2021-11-07 23:53:04.980446: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
'2.3.0'
>>> import tensorflow_datasets as tfds;tfds.__version__
'4.3.0'
>>> 
'''
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy
import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

# Load the Fashion-MNIST training split together with its DatasetInfo metadata.
ds_train, ds_info = tfds.load('fashion_mnist', split='train', shuffle_files=True, with_info=True)
# tfds 4.x takes (dataset, info); the older (info, dataset) order is deprecated.
fig = tfds.show_examples(ds_train, ds_info)

batch_size = 200
image_shape = (28, 28, 1)

def preprocess(features):
    """Turn a TFDS example dict into a float32 image scaled to [-1, 1].

    The [-1, 1] range matches the generator's tanh output.
    """
    image = tf.image.resize(features['image'], image_shape[:2])
    image = tf.cast(image, tf.float32)
    # Map pixel values in [0, 255] to [-1, 1].
    image = (image - 127.5) / 127.5
    return image


train_num = ds_info.splits['train'].num_examples

ds_train = ds_train.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()  # keep preprocessed images in memory after the first pass
ds_train = ds_train.shuffle(train_num)  # shuffle buffer spans the whole split
ds_train = ds_train.batch(batch_size).repeat()
# Prepare the next batches while the current training step runs.
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

# batch().repeat() yields ceil(train_num / batch_size) batches per epoch
# (the last may be partial), so use ceiling division rather than round().
train_steps_per_epoch = -(-train_num // batch_size)
print(train_steps_per_epoch)

class GAN():
    """Minimal trainer that updates a generator and a discriminator together.

    Subclasses provide the two networks and may override ``plot_images`` to
    visualise samples during training.
    """

    def __init__(self, generator, discriminator):
        # G maps noise batches to samples; D maps samples to scores. The BCE
        # loss below uses the default from_logits=False, so D is expected to
        # output probabilities.
        self.D = discriminator
        self.G = generator

        self.bce = tf.keras.losses.BinaryCrossentropy()
        # Training history, keyed by step index.
        self.d_loss = {}
        self.g_loss = {}
        self.accuracy = {}
        # One entry per step: the gradient norm of each generator variable.
        self.g_gradients = []

    def discriminator_loss(self, pred_fake, pred_real):
        """Average the BCE on real samples (target 1) and fake samples (target 0)."""
        real_loss = self.bce(tf.ones_like(pred_real), pred_real)
        fake_loss = self.bce(tf.zeros_like(pred_fake), pred_fake)

        d_loss = 0.5*(real_loss + fake_loss)
        return d_loss

    def generator_loss(self, pred_fake):
        """Return the BCE that rewards D scoring fake samples as real (target 1)."""
        g_loss = self.bce(tf.ones_like(pred_fake), pred_fake)
        return g_loss

    def train_step(self, g_input, real_input):
        """Run one simultaneous update of G and D.

        Args:
            g_input: batch of noise vectors for the generator.
            real_input: batch of real samples for the discriminator.

        Returns:
            ``(g_loss, d_loss, fake_accuracy, real_accuracy)`` as scalar tensors.

        ``train`` must have set ``G_optimizer`` and ``D_optimizer`` beforehand.
        """
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # Run in training mode so BatchNormalization normalises with batch
            # statistics (and updates its moving averages) and Dropout is
            # active. Without the flag a custom loop runs them in inference mode.
            fake_input = self.G(g_input, training=True)

            pred_fake = self.D(fake_input, training=True)
            pred_real = self.D(real_input, training=True)

            d_loss = self.discriminator_loss(pred_fake, pred_real)
            g_loss = self.generator_loss(pred_fake)

        # Metrics and gradients are computed after the tapes close, so these
        # ops are not recorded.
        fake_accuracy = tf.math.reduce_mean(binary_accuracy(tf.zeros_like(pred_fake), pred_fake))
        real_accuracy = tf.math.reduce_mean(binary_accuracy(tf.ones_like(pred_real), pred_real))

        gradient_g = g_tape.gradient(g_loss, self.G.trainable_variables)
        gradient_d = d_tape.gradient(d_loss, self.D.trainable_variables)

        # tf.norm defaults to the Euclidean (L2) norm.
        gradient_g_l2_norm = [tf.norm(gradient).numpy() for gradient in gradient_g]
        self.g_gradients.append(gradient_g_l2_norm)

        self.G_optimizer.apply_gradients(zip(gradient_g, self.G.trainable_variables))
        self.D_optimizer.apply_gradients(zip(gradient_d, self.D.trainable_variables))

        return g_loss, d_loss, fake_accuracy, real_accuracy

    def train(self, data_generator,
                    z_generator,
                    g_optimizer, d_optimizer,
                    steps, interval=100):
        """Train for ``steps`` iterations and record per-step losses and accuracy.

        Args:
            data_generator: iterator yielding batches of real samples.
            z_generator: iterator yielding batches of generator noise.
            g_optimizer, d_optimizer: Keras optimizers for G and D.
            steps: number of training steps.
            interval: every ``interval`` steps, print metrics and plot samples.

        Raises:
            ValueError: if ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError("interval must be a positive integer, got {}".format(interval))

        self.D_optimizer = d_optimizer
        self.G_optimizer = g_optimizer
        # One fixed noise batch, so samples plotted at different steps are comparable.
        val_g_input = next(z_generator)
        for i in range(steps):
            g_input = next(z_generator)
            real_input = next(data_generator)

            g_loss, d_loss, fake_accuracy, real_accuracy = self.train_step(g_input, real_input)
            d_loss_value = d_loss.numpy()
            g_loss_value = g_loss.numpy()
            fake_acc_value = fake_accuracy.numpy()
            real_acc_value = real_accuracy.numpy()

            self.d_loss[i] = d_loss_value
            self.g_loss[i] = g_loss_value
            self.accuracy[i] = 0.5*(fake_acc_value + real_acc_value)
            if i % interval == 0:
                msg = ("Step {}: d_loss {:.4f} g_loss {:.4f} Accuracy. real : {:.3f} fake : {:.3f}"
                       .format(i, d_loss_value, g_loss_value, real_acc_value, fake_acc_value))
                print(msg)

                # Sample in inference mode (BatchNormalization uses moving statistics).
                fake_images = self.G(val_g_input, training=False)
                self.plot_images(fake_images)

    def plot_images(self, images):
        """Hook for visualising generated samples; the base class does nothing."""
        pass

class DCGAN(GAN):
    """Deep convolutional GAN built from small Conv2D networks.

    Args:
        z_dim: length of the generator's input noise vector.
        input_shape: ``(height, width, channels)`` of the real images. Height and
            width must be divisible by 4, because the generator upsamples twice by 2.
    """

    def __init__(self, z_dim, input_shape):
        discriminator = self.Discriminator(input_shape)
        # The generator's output must match the discriminator's input shape.
        generator = self.Generator(z_dim, output_shape=input_shape)

        super().__init__(generator, discriminator)

    def Discriminator(self, input_shape):
        """Build D as two blocks of [Conv2D(stride 2) -> BN -> LeakyReLU -> Dropout],
        followed by a sigmoid Dense(1) probability head."""
        model = tf.keras.Sequential(name='Discriminator')
        model.add(layers.Input(shape=input_shape))

        model.add(layers.Conv2D(32, 3, strides=(2,2), padding='same'))
        model.add(layers.BatchNormalization(momentum=0.9))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dropout(0.2))

        model.add(layers.Conv2D(64, 3, strides=(2,2), padding='same'))
        model.add(layers.BatchNormalization(momentum=0.9))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dropout(0.2))

        model.add(layers.Flatten())
        model.add(layers.Dense(1, activation='sigmoid'))

        return model

    def Generator(self, z_dim, output_shape=None):
        """Build G, which maps a ``z_dim`` noise vector to an image in [-1, 1] (tanh).

        Args:
            z_dim: length of the input noise vector.
            output_shape: ``(height, width, channels)`` of the generated image.
                Defaults to the module-level ``image_shape``.

        Raises:
            ValueError: if height or width is not divisible by 4.
        """
        if output_shape is None:
            output_shape = image_shape
        height, width, channels = output_shape
        if height % 4 or width % 4:
            raise ValueError(
                "Generator output height/width must be divisible by 4, got {}".format(output_shape))
        # Two x2 upsamplings take the seed feature map up to the full resolution.
        base_h, base_w = height // 4, width // 4

        model = tf.keras.Sequential(name='Generator')
        model.add(layers.Input(shape=[z_dim]))

        model.add(layers.Dense(base_h*base_w*64))
        model.add(layers.BatchNormalization(momentum=0.9))
        model.add(layers.ReLU())
        model.add(layers.Reshape((base_h, base_w, 64)))

        model.add(layers.Conv2D(64, 3, padding='same'))
        model.add(layers.BatchNormalization(momentum=0.9))
        model.add(layers.ReLU())
        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))

        model.add(layers.Conv2D(32, 3, padding='same'))
        model.add(layers.ReLU())
        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))

        model.add(layers.Conv2D(channels, 3, padding='same', activation='tanh'))

        return model

    def plot_images(self, images):
        """Show up to 8 generated images in one row.

        Maps pixel values from the tanh range [-1, 1] to [0, 1] and displays
        channel 0 in grayscale. Slots without an image are left blank.
        """
        grid_row = 1
        grid_col = 8
        n_images = min(grid_col, images.shape[0])
        _, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*1.5, grid_row*1.5))
        for col, ax in enumerate(axarr):
            if col < n_images:
                ax.imshow((images[col, :, :, 0] + 1) / 2, cmap='gray')
            ax.axis('off')
        plt.show()

z_dim = 100


def z_generator(batch_size, z_dim):
    """Yield standard-normal noise batches of shape (batch_size, z_dim) forever."""
    noise_shape = (batch_size, z_dim)
    while True:
        yield tf.random.normal(noise_shape)


z_gen = z_generator(batch_size, z_dim)


# Build the DCGAN for image_shape-sized images and print both network summaries.
# Each triple-quoted string after a summary() call is a no-op expression holding
# that call's output from an earlier run, kept for reference.

gan = DCGAN(z_dim, image_shape)

gan.D.summary()
'''
Model: "Discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 14, 14, 32)        320       
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 14, 14, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 64)          256       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 3137      
=================================================================
Total params: 22,337
Trainable params: 22,145
Non-trainable params: 192
'''

gan.G.summary()
'''
Model: "Generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 3136)              316736    
_________________________________________________________________
batch_normalization_2 (Batch (None, 3136)              12544     
_________________________________________________________________
re_lu (ReLU)                 (None, 3136)              0         
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_normalization_3 (Batch (None, 7, 7, 64)          256       
_________________________________________________________________
re_lu_1 (ReLU)               (None, 7, 7, 64)          0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 32)        18464     
_________________________________________________________________
re_lu_2 (ReLU)               (None, 14, 14, 32)        0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 1)         289       
=================================================================
Total params: 385,217
Trainable params: 378,817
Non-trainable params: 6,400
'''

# Train for 25 epochs; every 2 epochs, print metrics and plot samples from fixed noise.
gan.train(iter(ds_train), z_gen,
          RMSprop(3e-4), RMSprop(3e-4),
          25*train_steps_per_epoch,
          2*train_steps_per_epoch)

# Plot the per-step losses recorded by GAN.train. plot_steps=None plots the
# whole run; set it to an int N to show only the first N steps.
plot_steps = None
d_losses = list(gan.d_loss.values())[:plot_steps]
g_losses = list(gan.g_loss.values())[:plot_steps]

fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(10, 8))
ax1.plot(d_losses, label='D loss', alpha=0.7)
ax1.set_title("D loss")
ax2.plot(g_losses, label='G loss', alpha=0.7)
ax2.set_title("G loss")
# The x-axis is shared, so only the bottom subplot needs a label.
ax2.set_xlabel('Steps')
plt.show()

训练过程(每 2 个 epoch 用固定噪声生成的样本图;原文图片缺失):

损失函数(D 与 G 的逐步损失曲线;原文图片缺失):

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/444538.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号