栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

SAGAN

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

SAGAN

https://github.com/taki0112/Self-Attention-GAN-Tensorflow

1、SAGAN文件的结构:
--dataset  #数据集文件,需要自己下载
	--celeba
		---img1.jpg
		---img2.jpg
--ops.py #图层文件
--utils.py #操作文件
--SAGAN.py#模型文件
--main.py#主函数文件
2、模型崩塌

模型崩塌是指生成器学习到一种能够欺骗判别器的特征后,所有的学习特征都会向这个特征靠拢。具体表现就是GAN一旦生成一张能够欺骗判别器的图像之后,那么其他特征会与这个特征非常接近,导致最终生成的结果中有一样或者类似的图像,这个现象在DCGAN中非常明显,后面的提出的模型基本很少出现这种现象了。

3、解决模型崩塌:

(1)分批打乱数据
(2)期望值特征匹配
(3)更新历史均值
(4)one-side dlabel smoothing
(5)virtual batch normalization

4、 train
  • python main.py --phase train --dataset celebA --gan_type hinge

总共有4个场景需要训练测试,
自动轮询这几个场景,一个场景训练完再训练另一个场景:

#!/usr/bin

echo 'start'
python main.py --phase train --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase train --dataset Berghouse --gan_type hinge
wait
python main.py --phase train --dataset DJI_0862 --gan_type hinge
wait
python main.py --phase train --dataset Bluemlisalphutte  --gan_type hinge

echo "end"

5、 test
  • python main.py --phase test --dataset celebA --gan_type hinge

测试:四个场景都测试

#!/usr/bin
echo 'start'
python main.py --phase test --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase test --dataset Bluemlisalphutte  --gan_type hinge
wait
python main.py --phase test --dataset DJI_0862  --gan_type hinge
wait
python main.py --phase test --dataset Berghouse --gan_type hinge
echo "end"

6、随机输入测试

测试源码是输入一张随机数据生成10张最好的相似图像

    def test(self):
        import time
        from PIL import Image
        import numpy as np
        from sklearn import preprocessing

        start_Time = time.time()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        #原代码
        for i in range(self.test_num) :
            z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                       [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
        end_Time = time.time()
        print("process time % s " % (end_Time - start_Time))
7、指定输入测试

修改测试代码:修改输入为指定图像,生成对应的图像

    def test(self):
        import time
        from PIL import Image
        import numpy as np
        from sklearn import preprocessing

        start_Time = time.time()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        #原代码
        #for i in range(self.test_num) :
        #    z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
        #    samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
        #    save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
        #               [image_frame_dim, image_frame_dim],
        #                result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
        #end_Time = time.time()
        #print("process time % s " % (end_Time - start_Time))
        path_list = os.listdir('./dataset/' + self.dataset_name)
        path_list.sort()
        data_num = len(path_list)

        for filename in path_list:
            print('filenamexxxxxxxxxxxx',filename)
            z_sample=Image.open(os.path.join('./dataset',self.dataset_name,filename))


            z_sample = z_sample.resize((64,64),Image.ANTIALIAS)
            z_sample = np.array(z_sample,dtype='int8')

            z_sample = z_sample.reshape(48, self.z_dim)
            z_sample = preprocessing.scale(z_sample)

            z_sample = z_sample.reshape(48, 1, 1, self.z_dim)

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(filename))
        end_Time = time.time()
        print("process time % s "% (end_Time-start_Time))
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/675042.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号