https://github.com/taki0112/Self-Attention-GAN-Tensorflow
1、SAGAN文件的结构:
--dataset        # 数据集文件,需要自己下载
  --celeba
    ---img1.jpg
    ---img2.jpg
--ops.py         # 图层文件
--utils.py       # 操作文件
--SAGAN.py       # 模型文件
--main.py        # 主函数文件

2、模型崩塌
模型崩塌是指生成器学习到一种能够欺骗判别器的特征后,所有的学习特征都会向这个特征靠拢。具体表现就是GAN一旦生成一张能够欺骗判别器的图像之后,那么其他特征会与这个特征非常接近,导致最终生成的结果中有一样或者类似的图像,这个现象在DCGAN中非常明显,后面的提出的模型基本很少出现这种现象了。
3、解决模型崩塌:(1)分批打乱数据
(2)期望值特征匹配
(3)更新历史均值
(4)one-sided label smoothing
(5)virtual batch normalization
- python main.py --phase train --dataset celebA --gan_type hinge
总共有4个场景需要训练测试,
自动轮询这几个场景,一个场景训练完再训练另一个场景:
#!/bin/bash
echo 'start'
python main.py --phase train --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase train --dataset Berghouse --gan_type hinge
wait
python main.py --phase train --dataset DJI_0862 --gan_type hinge
wait
python main.py --phase train --dataset Bluemlisalphutte --gan_type hinge
echo "end"

5、test
- python main.py --phase test --dataset celebA --gan_type hinge
测试:四个场景都测试
#!/bin/bash
echo 'start'
python main.py --phase test --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase test --dataset Bluemlisalphutte --gan_type hinge
wait
python main.py --phase test --dataset DJI_0862 --gan_type hinge
wait
python main.py --phase test --dataset Berghouse --gan_type hinge
echo "end"

6、随机输入测试
测试源码是输入一张随机数据生成10张最好的相似图像
def test(self):
    """Generate ``self.test_num`` image grids from random noise.

    Restores the latest checkpoint, then for each test iteration samples
    a uniform noise batch z in [-1, 1), runs the generator and writes the
    resulting grid of fake images into the result directory.

    NOTE(review): relies on ``self.sess``, ``self.fake_images``, ``self.z``
    and the helpers ``check_folder`` / ``save_images`` defined elsewhere
    in this project.
    """
    import time
    import numpy as np

    start_time = time.time()
    self.saver = tf.train.Saver()
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    result_dir = os.path.join(self.result_dir, self.model_dir)
    check_folder(result_dir)
    if could_load:
        print(" [*] Load SUCCESS")
    else:
        # Bug fix: without a restored checkpoint the generator weights are
        # uninitialized, so sampling would produce garbage or crash — abort.
        print(" [!] Load failed...")
        return

    # Side length of the square image grid saved per iteration.
    tot_num_samples = min(self.sample_num, self.batch_size)
    image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

    """ random condition, random noise """
    for i in range(self.test_num):
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                    [image_frame_dim, image_frame_dim],
                    result_dir + '/' + self.model_name + '_test_{}.png'.format(i))

    end_time = time.time()
    print("process time % s " % (end_time - start_time))
7、指定输入测试
修改测试代码:修改输入为指定图像,生成对应的图像
def test(self):
    """Generate images conditioned on dataset images instead of random noise.

    For every file in ``./dataset/<dataset_name>``: load the image, resize
    it to 64x64, standardize its pixel values column-wise, reshape the
    result into latent vectors of width ``self.z_dim``, feed them through
    the generator, and save the resulting image grid per input file.

    NOTE(review): the inferred batch size (64*64*channels / z_dim) must
    match what the ``self.z`` placeholder accepts — confirm against the
    model definition.
    """
    import time
    from PIL import Image
    import numpy as np
    from sklearn import preprocessing

    start_time = time.time()
    self.saver = tf.train.Saver()
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    result_dir = os.path.join(self.result_dir, self.model_dir)
    check_folder(result_dir)
    if could_load:
        print(" [*] Load SUCCESS")
    else:
        # Bug fix: without a restored checkpoint the generator weights are
        # uninitialized — abort instead of sampling garbage.
        print(" [!] Load failed...")
        return

    # Side length of the square image grid saved per input file.
    tot_num_samples = min(self.sample_num, self.batch_size)
    image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

    dataset_dir = os.path.join('./dataset', self.dataset_name)
    path_list = os.listdir(dataset_dir)
    path_list.sort()  # deterministic processing order across platforms
    for filename in path_list:
        print('filenamexxxxxxxxxxxx', filename)
        img = Image.open(os.path.join(dataset_dir, filename))
        img = img.resize((64, 64), Image.ANTIALIAS)
        # Bug fix: pixel values are 0-255; dtype 'int8' overflows above 127
        # and silently wraps to negative values, corrupting the input to
        # preprocessing.scale. Use a wider integer type instead.
        z_sample = np.array(img, dtype='int32')
        # Generalization: the row count was hard-coded to 48
        # (= 64*64*3 / 256); let numpy infer it so any z_dim that divides
        # the pixel count works.
        z_sample = z_sample.reshape(-1, self.z_dim)
        # Standardize each column to zero mean / unit variance so the fed
        # latent codes are on a scale comparable to the training noise.
        z_sample = preprocessing.scale(z_sample)
        z_sample = z_sample.reshape(-1, 1, 1, self.z_dim)
        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                    [image_frame_dim, image_frame_dim],
                    result_dir + '/' + self.model_name + '_test_{}.png'.format(filename))

    end_time = time.time()
    print("process time % s " % (end_time - start_time))



