栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

GAN原理详解(附代码)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

GAN原理详解(附代码)

G A N mathrm{GAN} GAN原理介绍

G A N mathrm{GAN} GAN是由生成器 G G G和判别器 D D D组成,通过大量样本数据训练使得生成器的生成能力和判别器的判别能力在对抗中逐步提高,最终目的是让生成器 G G G能够生成以假乱真的样本。具体的训练过程为:

  • 首先在真实数据集中采样出一批真实样本
  • 同时从某个分布中随机生成一些噪声向量
  • 接着将噪声向量输入到生成器中生成假的样本数据
  • 最后把真实样本与等量的假的样本输入到判别器中进行判别

    G A N mathrm{GAN} GAN训练的过程可以描述为求解一个二元函数极小极大值的过程,具体的公式如下所示: min ⁡ G max ⁡ D V ( D , G ) = min ⁡ G max ⁡ D E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] minlimits_{G}maxlimits_{D}V(D,G)=minlimits_{G}maxlimits_{D}mathbb{E}_{xsim p_{data}(x)}[log D(x)]+mathbb{E}_{zsim p_z(z)}[log (1-D(G(z)))] Gmin​Dmax​V(D,G)=Gmin​Dmax​Ex∼pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]优化对抗损失 V ( D , G ) V(D,G) V(D,G)可以同时达到两个目的,第一个目的是让生成器 G G G能够生成真实的样本,第二个目的是让判别器 D D D能更好地区分开真实样本和生成样本。
生成器 G G G

训练生成器的损失函数其实是对抗损失 V ( G , D ) V(G,D) V(G,D)中关于噪声 z z z的项,其损失函数为: L G = E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_{G}=mathbb{E}_{zsim p_z(z)}[log (1-D(G(z)))] LG​=Ez∼pz​(z)​[log(1−D(G(z)))]生成器的目标是希望生成器生成的样本越像真的越好,具体体现在数学公式中为:
min ⁡ G E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ]    ⟹    min ⁡ G E z ∼ p z ( z ) [ − log ⁡ ( D ( G ( z ) ) ) ]    ⟹    min ⁡ G E z ∼ p z ( z ) − [ 1 ⋅ log ⁡ ( D ( G ( z ) ) ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − D ( G ( z ) ) ) ]    ⟹    min ⁡ G 1 N ∑ i = 1 N B C E ( D ( G ( z i ) ) , 1 ) begin{aligned}&minlimits_{G}mathbb{E}_{zsim p_z(z)}[log (1-D(G(z)))]\implies&minlimits_{G}mathbb{E}_{zsim p_z(z)}[-log (D(G(z)))]\ implies&minlimits_{G}mathbb{E}_{zsim p_z(z)}-[1 cdot log (D(G(z))) + (1-1)cdot log(1-D(G(z)))]\implies&minlimits_{G} frac{1}{N}sumlimits_{i=1}^N mathrm{BCE}(D(G(z_i)),1)end{aligned} ⟹⟹⟹​Gmin​Ez∼pz​(z)​[log(1−D(G(z)))]Gmin​Ez∼pz​(z)​[−log(D(G(z)))]Gmin​Ez∼pz​(z)​−[1⋅log(D(G(z)))+(1−1)⋅log(1−D(G(z)))]Gmin​N1​i=1∑N​BCE(D(G(zi​)),1)​其中 B C E ( ⋅ , ⋅ ) mathrm{BCE}(cdot,cdot) BCE(⋅,⋅)表示二元交叉熵函数。则由上可知, G A N mathrm{GAN} GAN中生成器最小化损失函数 L G L_G LG​可以写成最小化二元交叉熵函数的形式。

判别器 D D D

训练生成器的损失函数其实是对抗损失 V ( G , D ) V(G,D) V(G,D)中关于样本 x x x的项,其损失函数为: L D = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E x ^ ∼ p g ( x ^ ) [ log ⁡ ( 1 − D ( x ^ ) ) ] L_D=mathbb{E}_{xsim p_{data}(x)}[log D(x)]+mathbb{E}_{hat{x}sim p_g(hat{x})}[log (1-D(hat{x}))] LD​=Ex∼pdata​(x)​[logD(x)]+Ex^∼pg​(x^)​[log(1−D(x^))]判别器的目标是希望判别器能够更好的区分出生成样本和真实样本,具体体现在数学公式中为: max ⁡ D E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E x ^ ∼ p g ( x ^ ) [ log ⁡ ( 1 − D ( x ^ ) ) ]    ⟹    min ⁡ D E x ∼ p d a t a ( x ) − [ log ⁡ D ( x ) ] + min ⁡ D E x ^ ∼ p g ( x ^ ) − [ log ⁡ ( 1 − D ( x ^ ) ) ]    ⟹    min ⁡ D E x ∼ p d a t a ( x ) − [ 1 ⋅ log ⁡ ( D ( x ) ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − D ( x ) ) ] + min ⁡ D E x ^ ∼ p g ( x ^ ) − [ 0 ⋅ log ⁡ ( D ( x ^ ) ) + ( 1 − 0 ) ⋅ log ⁡ ( 1 − D ( x ^ ) ) ]    ⟹    min ⁡ D 1 N ∑ i = 1 N [ B C E ( D ( x i ) , 1 ) + B C E ( D ( x ^ i ) , 0 ) 2 ] begin{aligned}&maxlimits_{D}mathbb{E}_{xsim p_{data}(x)}[log D(x)]+mathbb{E}_{hat{x}sim p_g(hat{x})}[log (1-D(hat{x}))]\ implies& minlimits_{D}mathbb{E}_{xsim p_{data}(x)}-[log D(x)]+minlimits_{D}mathbb{E}_{hat{x}sim p_g(hat{x})}-[log (1-D(hat{x}))]\ implies&minlimits_{D}mathbb{E}_{xsim p_{data}(x)}-[1 cdot log (D(x)) + (1-1)cdot log(1-D(x))]\+& minlimits_{D}mathbb{E}_{hat{x}sim p_{g}(hat{x})}-[0 cdot log (D(hat{x})) + (1-0)cdot log(1-D(hat{x}))]\ implies&minlimits_{D} frac{1}{N}sumlimits_{i=1}^{N} left[frac{mathrm{BCE}(D(x_i),1)+mathrm{BCE}(D(hat{x}_i),0)}{2}right]end{aligned} ⟹⟹+⟹​Dmax​Ex∼pdata​(x)​[logD(x)]+Ex^∼pg​(x^)​[log(1−D(x^))]Dmin​Ex∼pdata​(x)​−[logD(x)]+Dmin​Ex^∼pg​(x^)​−[log(1−D(x^))]Dmin​Ex∼pdata​(x)​−[1⋅log(D(x))+(1−1)⋅log(1−D(x))]Dmin​Ex^∼pg​(x^)​−[0⋅log(D(x^))+(1−0)⋅log(1−D(x^))]Dmin​N1​i=1∑N​[2BCE(D(xi​),1)+BCE(D(x^i​),0)​]​其中 x x x表示真实的数据样本, x ^ hat{x} x^表示生成器生成的样本。由上可以发现, G A N mathrm{GAN} GAN中判别器最大化损失函数 L D L_D LD​可以写成最小化两个二元交叉熵函数的形式。

G A N mathrm{GAN} GAN代码介绍

本节介绍用 G A N mathrm{GAN} GAN和 D C G A N mathrm{DCGAN} DCGAN生成 m n i s t mathrm{mnist} mnist手写数据集, G A N mathrm{GAN} GAN和 D C G A N mathrm{DCGAN} DCGAN的网络结构以及完整的实现代码分别如下所示:

import argparse
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.utils import save_image

def args_parse():
	parser = argparse.ArgumentParser()
	parser.add_argument("--n_epoches", type=int, default=100, help="number of epochs of training")
	parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
	parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
	parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
	parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
	parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
	parser.add_argument("--channels", type=int, default=1, help="number of image channels")
	parser.add_argument("--sample_interval", type=int, default=100, help="interval between image sampling")
	parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
	parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
	parser.add_argument("--type", type=str, default='GAN', help="The type of GAN")
	return parser.parse_args()

class Generator(nn.Module):
	def __init__(self,	latent_dim, img_shape):
		super(Generator, self).__init__()
		self.img_shape = img_shape

		def block(in_feat, out_feat, normalize=True):
			layers = [nn.Linear(in_feat, out_feat)]
			if normalize:
				layers.append(nn.BatchNorm1d(out_feat, 0.8))
			layers.append(nn.LeakyReLU(0.2, inplace=True))
			return layers

		self.model = nn.Sequential(
				*block(latent_dim, 128, normalize=False),
				*block(128, 256),
				*block(256, 512),
				*block(512, 1024),
				nn.Linear(1024, int(np.prod(img_shape))),
				nn.Tanh()
		)


	def forward(self, z):
		img = self.model(z)
		img = img.view(img.size(0), self.img_shape[0], self.img_shape[1], self.img_shape[2])
		return img

class Discriminator(nn.Module):
	def __init__(self, img_shape):
		super(Discriminator, self).__init__()

		self.model = nn.Sequential(
			nn.Linear(int(np.prod(img_shape)), 512),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Linear(512, 256),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Linear(256, 1),
			nn.Sigmoid(),
		)

	def forward(self, img):
		img_flat = img.view(img.size(0), -1)
		validity = self.model(img_flat)
		return validity


class Generator_CNN(nn.Module):
	def __init__(self,	latent_dim, img_shape):
		super(Generator_CNN, self).__init__()

		self.init_size = img_shape[1] // 4  
		self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))   # 100 ——> 128 * 8 * 8 = 8192

		self.conv_blocks = nn.Sequential(
			nn.BatchNorm2d(128),
			nn.Upsample(scale_factor = 2),
			nn.Conv2d(128, 128, 3, stride=1, padding=1),
			nn.BatchNorm2d(128, 0.8),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Upsample(scale_factor = 2),
			nn.Conv2d(128, 64, 3, stride=1, padding=1),
			nn.BatchNorm2d(64, 0.8),
			nn.LeakyReLU(0.2, inplace=True),
			nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
			nn.Tanh()
		)

	def forward(self, z):
		out = self.l1(z)
		out = out.view(out.shape[0], 128, self.init_size, self.init_size)
		img = self.conv_blocks(out)
		return img

class Discriminator_CNN(nn.Module):
	def __init__(self, img_shape):
		super(Discriminator_CNN, self).__init__()

		def discriminator_block(in_filters, out_filters, bn=True):
			block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
					 nn.LeakyReLU(0.2,inplace=True),
					 nn.Dropout2d(0.25)]
			if bn:
				block.append(nn.BatchNorm2d(out_filters, 0.8))
			return block

		self.model = nn.Sequential(
			*discriminator_block(img_shape[0], 16, bn=False),
			*discriminator_block(16, 32),
			*discriminator_block(32, 64),
			*discriminator_block(64, 128),

		)

		ds_size = img_shape[1] // 2 ** 4
		self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) # 128 * 2 * 2 ——> 1

	def forward(self, img):
		out = self.model(img) 
		out = out.view(out.shape[0], -1)
		validity = self.adv_layer(out)
		return validity

def train():
	opt = args_parse()

	transform = transforms.Compose(
				[
					transforms.Resize(opt.img_size),
					transforms.ToTensor(),
					transforms.Normalize([0.5],[0.5])	
				])

	mnist_data = datasets.MNIST(
					"mnist-data", 
					train=True, 
					download=True,
					transform = transform
				)

	train_loader = torch.utils.data.DataLoader(
					mnist_data, 
					batch_size=opt.batch_size, 
					shuffle=True)


	img_shape = (opt.channels, opt.img_size, opt.img_size)
	# Construct generator and discriminator
	if opt.type == 'DCGAN':
		generator = Generator_CNN(opt.latent_dim, img_shape)
		discriminator = Discriminator_CNN(img_shape)
	else:
		generator = Generator(opt.latent_dim, img_shape)
		discriminator = Discriminator(img_shape)

	adversarial_loss = torch.nn.BCELoss()

	cuda = True if torch.cuda.is_available() else False
	if cuda:
		generator.cuda()
		discriminator.cuda()
		adversarial_loss.cuda()

	# Loss function


	# Optimizers
	optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
	optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

	print(generator)
	print(discriminator)

	Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
	for epoch in range(opt.n_epoches):
		for i, (imgs, _) in enumerate(train_loader):
			# adversarial ground truths
			valid = torch.ones(imgs.shape[0], 1).type(Tensor)
			fake = torch.zeros(imgs.shape[0], 1).type(Tensor)

			real_imgs = imgs.type(Tensor)

			#############    Train Generator    ################
			optimizer_G.zero_grad()

			# sample noise as generator input
			z = torch.tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).type(Tensor)

			# Generate a batch of images
			gen_imgs = generator(z)

			# G-Loss
			g_loss = adversarial_loss(discriminator(gen_imgs), valid)
			g_loss.backward()
			optimizer_G.step()

			#############  Train Discriminator ################
			optimizer_D.zero_grad()

			# D-Loss
			real_loss = adversarial_loss(discriminator(real_imgs), valid)
			fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
			d_loss = (real_loss + fake_loss) / 2

			d_loss.backward()
			optimizer_D.step()

			print(
				"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G Loss: %f]" 
				% (epoch, opt.n_epoches, i , len(train_loader), d_loss.item(), g_loss.item())
			)

			batches_done = epoch * len(train_loader) + i 
			os.makedirs("images", exist_ok=True)
			if batches_done % opt.sample_interval == 0:
				save_image(gen_imgs.data[:25], "images/%d.png" % (batches_done), nrow=5, normalize=True )

if __name__ == '__main__':
	train()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/655434.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号