栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

自编码器

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

自编码器

# 定义输入、标签。这里标签y_就是输入x x tf.compat.v1.placeholder(tf.float32, [None, 784]) y_ tf.compat.v1.placeholder(tf.float32, [None, 784]) # 计算当前参数在神经网络上的结果 y inference(x) # 定义损失函数 loss_mean tf.reduce_mean(tf.reduce_sum(tf.square(y - y_))) # 将loss_mean加入损失集合。 tf.compat.v1.add_to_collection( losses , loss_mean) # 总损失函数 loss tf.add_n(tf.compat.v1.get_collection( losses )) # 初始速率0.1 后面每训练100次后在学习速率基础上乘以0.96 learning_rate tf.compat.v1.train.exponential_decay(0.9999, global_step, 5000, 0.9, staircase True) # 使用tf.train.GradientDescentOptimizer 优化算法来优化损失函数。 train_step tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step global_step) # 加载样本数据 data_feed data_init() # 初始化会话并开始训练过程。 init_var tf.compat.v1.global_variables_initializer() with tf.compat.v1.Session() as sess: sess.run(init_var) for i in range(TRAINING_STEPS): train_batch_x data_batch_set(data_feed, feed_name x_train ) # 标签就是样本x sess.run(train_step, feed_dict {x: train_batch_x, y_: train_batch_x}) if i % 500 0: # loss loss_val sess.run(loss, feed_dict {x: train_batch_x, y_: train_batch_x}) print( After %d training step(s) , loss %f % (i, loss_val)) # test test_batch_x data_batch_set(data_feed, feed_name x_test ) # test_img test_x test_batch_x[0] # 原图 test_x_img test_x * 255 test_x_img np.reshape(test_x_img, (28, 28)) # 重建后的图 # 转换为0-1之间的值 test_y tf.nn.sigmoid(y) test_y sess.run(test_y, feed_dict {x: [test_x]}) test_y_img test_y * 255 test_y_img np.reshape(test_y_img, (28, 28)) plt.subplot(1, 2, 1) plt.title( origin ) plt.imshow(test_x_img) plt.subplot(1, 2, 2) plt.title( forecast ) plt.imshow(test_y_img) plt.show()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267472.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号