栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

实验:打造自己的MNIST-GAN

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

实验:打造自己的MNIST-GAN

实验:打造自己的MNIST-GAN

文章目录
  • 实验:打造自己的MNIST-GAN
    • 1 实验内容
    • 2 实验原理
      • Basic Idea of GAN
      • Algorithm
    • 3 具体实现
      • 使用原生GAN实现
        • 加载MNIST数据
        • 构建生成器
        • 构建判别器
        • 损失函数与优化
        • 随机采样
        • 交替训练
        • 生成结果
      • 使用CNN+GAN实现
        • 更改生成网络结构
        • 更改判别网络结构
        • 训练过程
        • 生成结果
        • 观察linearly interpolating结果
      • 使用CGAN实现
        • 更改生成网络结构
        • 更改判别网络结构
        • 交替训练
        • 生成结果:

1 实验内容

借助Keras,Tensorfolow 或Pytorch 等框架,设计和搭建自己的MNIST-GAN 图像生成器,生成新的手写数字图片

要求:

  • 实现MNIST 数据加载和可视化

  • 搜索和阅读相关资料和论文,在Keras,Tensorfolow或Pytorch 任意框架下实现MNIST-GAN网络的构建和训练

  • 使用训练好的MNIST-GAN 网络产生新的0-9 手写数字图片,并在训练数据集中找出和新生成图片‘‘最接近’’(可自行定义接近程度,或者尝试多种方式后人工比较)的训练图片

  • 使用linearly interpolating 完成下图中效果(图片来源:Figure 3 in Generative Adversarial Nets, Ian J. Goodfellow, et al.)

  • (选做)GAN 的训练被认为相对困难(可参见‘‘参考资料’’),总结在实验中遇到的问题,搜索资料,尝试不同的解决方案并总结

2 实验原理 Basic Idea of GAN

Algorithm

3 具体实现 使用原生GAN实现 加载MNIST数据
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),

    batch_size=opt.batch_size,
    shuffle=True,
)

这里随机取几张图片观察。

def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
    
mnist = datasets.MNIST("../../data/mnist")

构建生成器

仿照下图的原生GAN的结构来搭建。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-un1cmauM-1634715042902)(https://i.loli.net/2021/10/19/HYN87qkdefZhmyl.png)]

我们的生成器包含5个全连接层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

结构如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b7qAO5Zw-1634715042904)(https://i.loli.net/2021/10/18/ZsElTonhgqWweQv.png)]

构建判别器

仿照原生GAN,使用全连接网络,把Maxout激活函数换为ReLU与Sigmoid。

包含3个全连接层,使用LeakyReLU和Sigmoid激活函数。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
    
discriminator = Discriminator()
print(discriminator)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TpTQFtR8-1634715042908)(https://i.loli.net/2021/10/18/4ALVpd8lhnOPWzi.png)]

损失函数与优化

判别器使用 Binary Cross Entropy Loss。

优化都使用Adam,lr = 0.0002。

optimizer_G = torch**.**optim**.**Adam(generator**.**parameters(), lr=opt**.**lr, betas=(opt**.**b1, opt**.**b2))

optimizer_D = torch**.**optim**.**Adam(discriminator**.**parameters(), lr=opt**.**lr, betas=(opt**.**b1, opt**.**b2))
随机采样

从100维的正态分布中采样作为z。

一个batch有64组输入。

z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
交替训练
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

real_imgs = Variable(imgs.type(Tensor))

#更新生成器

optimizer_G.zero_grad()

#采样z
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
gen_imgs = generator(z)

#生成器权值更新
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

#更新判别器
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
生成结果

每400次迭代观察一次当前生成图像。

最开始,生成全是杂讯。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NHUs123v-1634715042910)(https://i.loli.net/2021/10/19/gdMUDkeC42OJ6sR.png)]

开始设置的epoch数很少,结果很差,下图是第6000次迭代的结果:

20000次:

100000次:

200个epoch以后,也就是十八万多次迭代以后的最终结果:

感觉没有很好的结果,还需要继续train下去,但没有继续尝试了。

使用CNN+GAN实现 更改生成网络结构
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

网络结构为:

更改判别网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

网络结构为:

训练过程

生成结果

比用原生GAN的结果好很多。

比如:

第6000次迭代:

第20000次迭代:

]

第100个epoch:

第120个epoch:

观察linearly interpolating结果

随机选两个点,在两点中取10个点观察变化过程:

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
g = torch.load('model/generator.pkl')
z = Variable(Tensor(np.random.normal(0, 1, (2, 100))))
a = torch.FloatTensor(100, 20)
for i in range(100):
    a[i] = torch.linspace(z[0][i], z[1][i], 10)

b = Variable(a.t())
b = b.to('cuda')
gen_imgs = g(b)
save_image(gen_imgs.data[:], "images_trans.png", normalize=True)

再次尝试观察更细致的变化:

]

使用CGAN实现

为了可以控制输出我们可以使用CGAN

在原生GAN结构基础上,更改网络结构如下:

更改生成网络结构

更改判别网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

结构如下:

交替训练

把标签引入训练。

    batch_size = imgs.shape[0]
    valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) # 为1时判定为真
    fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) # 为0时判定为假
    
    optimizer_G.zero_grad()
    gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
    
	#更新生成器
    gen_imgs = generator(z, gen_labels)
    print("gen_imgs =")
    for img in gen_imgs[:3]:
        show_img(img)

    validity = discriminator(gen_imgs, gen_labels)
    g_loss = adversarial_loss(validity, valid)
    print("g_loss =", g_loss, 'n')

    g_loss.backward()
    optimizer_G.step()

   #更新判别器

    optimizer_D.zero_grad()

    validity_real = discriminator(real_imgs, labels)
    d_real_loss = adversarial_loss(validity_real, valid)
    validity_fake = discriminator(gen_imgs.detach(), gen_labels)
    
    d_fake_loss = adversarial_loss(validity_fake, fake)
    d_loss = (d_real_loss + d_fake_loss) / 2
    print("real_loss =", d_real_loss, 'n')
    print("fake_loss =", d_fake_loss, 'n')
    print("d_loss =", d_loss, 'n')    
    
    d_loss.backward()
    optimizer_D.step()
生成结果:

100个epoch后的结果

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339054.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号