栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

基于最基础的GAN生成动漫头像

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

基于最基础的GAN生成动漫头像

最近在学习换脸相关的技术 在看完FaceShifter的论文和代码后就对GAN的思想产生了兴趣。而在看完陈云大佬的《深度学习框架PyTorch 入门与实践》GAN那一章后 就产生了用GAN生成动漫头像的念头。

基本思想

GAN 又叫生成对抗网络是一种非监督的学习。该网络中有一个生成器 Generators 和判别器 Discriminators 而训练过程就是这两个网络不断博弈对抗的过程。生成器不断生成假图企图通过判别器的识别 而判别器则将图片划分为真实图像和生成图像。

在该模型中 生成器的输入是一串噪音 输出是一张生成的假图 而生成器致力于让判别器无法识别出这张假图是生成的图还是真实图像。在训练过程中 不断用判别器的分数做反馈使生成器效果越来越好。判别器的输入是一张图片 输出则是图片的分数 分数越高说明此时生成的图像越接近真实图像。判别器致力于识别图片是真图还是假图 在训练过程中不断投喂假图 输出一个分数再与真实图像的标签进行比较。实际上也是一个二分类的过程。 代码实现 获取数据

网上做这种模型的人非常多 所以动漫头像的数据集也非常多。不过我大致看了一下 网上的数据集中的动漫头像都非常古老 颇有90年代日本动漫的画风 可能都是老二次元 。在这个模型中 我用的是自己在一个网站爬下来的数据。网站链接在这/ | konachan.net - Konachan.com Anime Wallpapers。

import time
import requests
import tqdm
from bs4 import BeautifulSoup
import os
import traceback# python异常模块
# 爬取图片
def download(url,filename,proxies):
 # 判断此时文件是否存在
 if os.path.exists(filename):
 print( file exists )
 return
 try:
 time.sleep(1)
 r requests.get(url,stream True,timeout 60,proxies proxies)# 以流数据形式请求
 r.raise_for_status()
 with open(filename, wb ) as f:
 for chunk in r.iter_content(chunk_size 1024):
 if chunk:# 当这个文件存在时
 f.write(chunk)
 f.flush()
 return filename
 except KeyboardInterrupt:
 if os.path.exists(filename):# 此时出错说明该文件不存在任何数据 若保存过该文件则删除
 os.remove(filename)
 raise KeyboardInterrupt
 except Exception:
 traceback.print_exc()# 把返回信息输出到控制台
 if os.path.exists(filename):
 os.remove(filename)
if os.path.exists( imgs ) is False:
 os.makedirs( imgs )
proxy 127.0.0.1:58591 #
proxies {
 http : http:// proxy,
 https : https:// proxy
start 1
end 8000# 8k张图片
for i in tqdm.tqdm(range(start,end 1),desc download anime picture ing ~ ):# tqdm括号内的必须是一个迭代器
 time.sleep(1)
 url https://konachan.net/post?page %d tags % i# 网站
 html requests.get(url,verify True, proxies proxies).text# 获取html网页上的内容
 soup BeautifulSoup(html, html.parser )
 for img in soup.find_all( img ,class_ preview ):# 找到原网站中含有图片文件网站
 target_url img[ src ]
 filename os.path.join( imgs/true_imgs ,target_url.split( / )[-1])
 download(target_url,filename,proxies)

可能是网站的原因 若不加sleep()会返回连接超时的报错 我猜可能是访问的太频繁了。不过具体原因我也不太清楚 对爬虫这一块不是很熟悉。

从这个网站爬下来的图片都是一些动漫壁纸 可能有些包含人物 而有些不包含。这里我用了openCV的一块模块来识别图像中的头像 并把它截取下来做为接下来训练的数据。

头像数据
# 从动漫壁纸中截取人物头像
import cv2
import sys
import os
from glob import glob
def detect(filename,cascade_file lbpcascade_animeface.xml ):
 if not os.path.isfile(cascade_file):
 raise RuntimeError( %s: not found % cascade_file)
 cascade cv2.CascadeClassifier(cascade_file)# 目标检测
 image cv2.imread(filename)# 打开图片
 gray cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
 gray cv2.equalizeHist(gray)
 faces cascade.detectMultiScale(gray,
 scaleFactor 1.1,
 minNeighbors 5,
 minSize (48,48))
 for i ,(x,y,w,h) in enumerate(faces):
 face image[y:y h,x:x w,:]# 得到图像像素点的分布
 face cv2.resize(face,(96,96))
 save_filename {}-{}.jpg .format(os.path.basename(filename).split( . )[0],i)
 cv2.imwrite( data/faces/ save_filename,face)# 写入文件
if __name__ __main__ :
 if os.path.exists( data/faces ) is False:
 os.makedirs( data/faces )
 file_list glob( imgs/true_imgs/*.jpg )# 将imgs中所有图片路径整合为一个迭代器
 for filename in file_list:
 detect(filename)

这两块 我都借鉴了这位大佬的代码利用GAN生成动漫头像_一个追逐自我的程序员的博客-CSDN博客

网络结构
# 生成器
class NetG(nn.Module):
 def __init__(self,opt):
 super(NetG,self).__init__()
 ngf opt.ngf# 生成器feature map数
 # 生成器主要的网络模块
 self.main nn.Sequential(
 # 输入是一个nz维的噪音 是一个随机生成的张量 可以认为是大小为1x1的feature amp
 nn.ConvTranspose2d(opt.nz,ngf*8,kernel_size 4,stride 1,padding 0,bias False),# 反卷积 做上采样
 nn.BatchNorm2d(ngf*8),
 nn.ReLU(True),
 # 上一步的输出形状 (ngf*8) x 4 x 4
 nn.ConvTranspose2d(ngf*8,ngf*4,kernel_size 4,stride 2,padding 1,bias False),# 继续上采样 不断减小图片维度
 nn.BatchNorm2d(ngf*4),
 nn.ReLU(True),
 # 上一步的输出形状 (ngf*4) x 8 x 8
 nn.ConvTranspose2d(ngf*4,ngf*2,kernel_size 4,stride 2,padding 1,bias False),
 nn.BatchNorm2d(ngf*2),
 nn.ReLU(True),
 # 上一步的输出形状 (ngf*2) x 16 x 16
 nn.ConvTranspose2d(ngf*2,ngf,kernel_size 4,stride 2,padding 1,bias False),
 nn.BatchNorm2d(ngf),
 nn.ReLU(True),
 # 上一步的输出形状 (ngf) x 32 x 32
 nn.ConvTranspose2d(ngf,3,kernel_size 5,stride 3,padding 1,bias False),
 nn.Tanh()# 输出范围固定在 -1 ~ 1故而采用Tanh
 # 输出形状 3 x 96 x 96
 # feature map经过解码过程 最后生成一个图片
 def forward(self,input):
 return self.main(input)
# 判别器
class NetD(nn.Module):
 def __init__(self,opt):
 super(NetD,self).__init__()
 ndf opt.ndf
 self.main nn.Sequential(
 # 输入3*96*96即生成器生成的图片
 nn.Conv2d(3,ndf,kernel_size 5,stride 3,padding 1,bias False),# 卷积 下采样 也是编码的过程
 nn.LeakyReLU(0.2,inplace True),
 # 输出 ndf x 32 x32
 nn.Conv2d(ndf,ndf*2,kernel_size 4,stride 2,padding 1,bias False),# 正好将feature map图片大小缩小一半
 nn.BatchNorm2d(ndf*2),
 nn.LeakyReLU(0.2,inplace True),
 # 输出 (ndf*2) x 16 x 16
 nn.Conv2d(ndf*2,ndf*4,kernel_size 4,stride 2,padding 1,bias False),
 nn.BatchNorm2d(ndf*4),
 nn.LeakyReLU(0.2,inplace True),
 # 输出 (ndf*4) x 8 x 8
 nn.Conv2d(ndf*4,ndf*8,kernel_size 4,stride 2,padding 1,bias False),
 nn.BatchNorm2d(ndf*8),
 nn.LeakyReLU(0.2,inplace True),
 # 输出 (ndf*8) x 4 x 4
 nn.Conv2d(ndf*8,1,kernel_size 4,stride 1,padding 0,bias False),# 最后编码成为一个维度为1的向量
 nn.Sigmoid()# 最后用Sigmoid作为分类 使得判别器成为一个判断二分类问题的模型 实际上判别器也是做一个二分类任务 判断是否为原图输出0或1
 def forward(self,input):
 return self.main(input).view(-1)# 转成一个列向量 即sigmoid的结果在更前面的维度

我个人觉得 这个GAN中的下采样再上采样的过程应该也借鉴了U-Net的网络结构。不过这里是将噪音编码为图像 再将图像解码为一个score。

模型读取
def model(device,pth False):
 netg NetG(Config).to(device)
 netd NetD(Config).to(device)
 if pth:
 netg.load_state_dict(torch.load(Config.load_G))
 netd.load_state_dict(torch.load(Config.load_D))
 return netg,netd
class Config():
 data_path data/ 
 num_workers 4
 image_size 96# 输入和输出的图片尺寸
 batch_size 64
 max_epoch 4000
 lr_G 2e-4# 生成器的学习率
 lr_D 2e-4# 判别器的学习率
 beta1 0.5# Adam优化器的beta1参数
 nz 100# 产生的噪音维度
 ngf 64# 生成器feature map数
 ndf 64# 判别器feature map数
 save_img_path generate_img # 生成的图片保存路径
 save_model_G_path ppppth/G 
 save_model_D_path ppppth/D 
 load_D ppppth/D/Anime_GAN_Dlast.pth 
 load_G ppppth/G/Anime_GAN_Glast.pth 
 vis True# 是否使用可视化
 env GAN 
 plot_every 20# 每间隔20 batch visdom画图一次
 d_every 1# 每一个batch训练一次判别器
 g_every 5# 每五个batch训练一次生成器
 save_every 20# 每20个epoch保存一次模型
 # 只测试不训练
 gen_img imgs/generate_head/result.png # 从512张生成的图片中保存最好的64张
 gen_num 64
 gen_search_num 512
 gen_mean 0 # 噪声的均值
 gen_std 1 # 噪声的方差

这里的dataset我使用的是torch自带的ImageFolder

import os
import torch
import visdom
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as T
from model import NetD,NetG
from config import Config
from visualize import Visualizer
def train():
 # data
 data_path Config.data_path
 image_size Config.image_size
 transform transforms()
 batch_size Config.batch_size
 vis Visualizer(Config.env)
 datasets torchvision.datasets.ImageFolder(data_path,transform transform)# 使用这个ImageFolder时 图片的路径必须是所处文件夹的上一级路径 即是data/而不是data/faces/
 dataloader DataLoader(
 datasets,
 batch_size batch_size,
 shuffle True,
 num_workers Config.num_workers,
 drop_last True
 # model
 device cuda 
 G,D model(device device)
 # 优化器和损失函数
 lr_G Config.lr_G# 生成器学习率
 lr_D Config.lr_D# 判别器学习率
 beta Config.beta1
 optimizer_G torch.optim.Adam(G.parameters(),lr lr_G,betas (beta,0.999))
 optimizer_D torch.optim.Adam(D.parameters(),lr lr_D,betas (beta,0.999))
 criterion torch.nn.BCELoss().to(device)# 因为最终是一个二分类的问题
 # 标签 0为假图片 1为真图片
 t_label torch.ones(batch_size).to(device)
 f_label torch.zeros(batch_size).to(device)
 # 噪音 用于生成图片
 noise torch.randn(batch_size,Config.nz,1,1).to(device)# 1x1大小的噪音
 val_noise torch.randn(batch_size,Config.nz,1,1).to(device)
 epochs Config.max_epoch
 loss_add_g torch.tensor(0.0,device device)
 loss_add_d torch.tensor(0.0,device device)

transform我没做什么特殊的数据增强

def transforms():
 transforms T.Compose([
 T.CenterCrop(Config.image_size),
 T.ToTensor(),
 T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
 return transforms

 

开始训练
 for epoch in range(epochs):
 for i,(img,_) in enumerate(dataloader):
 img img.to(device)
 # 致力于让生成的假图骗过判别器
 if (i 1) % Config.g_every 0:
 # 训练生成器
 optimizer_G.zero_grad()
 noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))# 使每次训练生成器时 噪音不同
 fake_img G(noise)
 out D(fake_img)
 loss criterion(out,t_label)
 loss_add_g loss
 loss.backward()
 optimizer_G.step()
 loss_G_mean loss_add_g / (i 1)
 # 致力于让判别器能识别出真图和假图
 if (i 1) % Config.d_every 0:
 # 训练判别器,训练判别器要训练两部分
 optimizer_D.zero_grad()
 # 尽可能让判别器识别图片为真
 real_output D(img)
 loss_r criterion(real_output,t_label)# 使判别器尽量识别出源图片是真图片
 loss_r.backward()
 # 尽可能让判别器识别为假
 noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))
 fake_img G(noise)# 根据噪音生成图片
 fake_output D(fake_img)
 loss_f criterion(fake_output,f_label)# 使判别器尽量识别出生成的图片是假的图片
 loss_f.backward()
 loss loss_f loss_r
 optimizer_D.step()
 loss_add_d loss
 loss_D_mean loss_add_d / (i 1)
 # 每隔plot_every个batch在visdom上画一次图
 if Config.vis and i % Config.plot_every Config.plot_every - 1:
 generate_img G(val_noise)
 vis.images(generate_img.detach().cpu().numpy()[:64]*0.5 0.5,win fake )
 vis.images(img.data.cpu().numpy()[:64]*0.5 0.5,win real )
 vis.plot( loss_g ,loss_G_mean.data.cpu().numpy())
 vis.plot( loss_d ,loss_D_mean.data.cpu().numpy())
 loss_add_g torch.tensor(0.0, device device)
 loss_add_d torch.tensor(0.0, device device)
 print( Generators: loss_G {} , Discriminators: loss_D {} .format(loss_G_mean, loss_D_mean))
 if (epoch 1) % Config.save_every 0:
 torch.save(G.state_dict(),os.path.join(Config.save_model_G_path, Anime_GAN_Glast.pth ))
 torch.save(D.state_dict(), os.path.join(Config.save_model_D_path, Anime_GAN_Dlast.pth ))

数据的可视化参考了陈云大佬的代码 写了一个由visdom的实现的模块。

from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267213.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号