import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #手写数字相关的数据包
# 载入数据集
mnist input_data.read_data_sets( MNIST_data ,one_hot True) #载入数据 {数据集包路径 把标签转化为只有0和1的形式}
#定义变量 即每个批次的大小
batch_size 100 #一次放100章图片进去
n_batch mnist.train.num_examples // batch_size #计算一共有多少个批次 训练集数量 整除 一个批次大小
#定义两个placeholder
x tf.placeholder(tf.float32,[None,784]) #[行不确定 列为784]
y tf.placeholder(tf.float32,[None,10]) #数字为0-9 则为10
#创建简单的神经网络
W tf.Variable(tf.zeros([784,10])) #权重
b tf.Variable(tf.zeros([10])) #偏置
prediction tf.nn.softmax(tf.matmul(x,W) b) #预测
#定义二次代价函数
# loss tf.reduce_mean(tf.square(y-prediction))
#定义交叉熵代价函数
loss tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels y,logits prediction))
#使用梯度下降法
train_step tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init tf.global_variables_initializer()
#准确数 结果存放在一个布尔型列表中
correct_prediction tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) #比较两个参数大小是否相同 同则返回为true 不同则返回为false argmax() 返回张量中最大的值所在的位置
#求准确率
accuracy tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #cast() 将布尔型转换为32位的浮点型 比方说9个T和1个F 则为9个1 1个0 即准确率为90%
with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(n_batch):
batch_xs,batch_ys mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict {x:batch_xs,y:batch_ys})
acc sess.run(accuracy,feed_dict {x:mnist.test.images,y:mnist.test.labels})
print( Iter str(epoch) ,Testing Accuracy str(acc))
Extracting MNIST_datatrain-images-idx3-ubyte.gz Extracting MNIST_datatrain-labels-idx1-ubyte.gz Extracting MNIST_datat10k-images-idx3-ubyte.gz Extracting MNIST_datat10k-labels-idx1-ubyte.gz Iter0,Testing Accuracy0.8502 Iter1,Testing Accuracy0.8954 Iter2,Testing Accuracy0.9014 Iter3,Testing Accuracy0.9052 Iter4,Testing Accuracy0.9079 Iter5,Testing Accuracy0.91 Iter6,Testing Accuracy0.9115 Iter7,Testing Accuracy0.9132 Iter8,Testing Accuracy0.9152 Iter9,Testing Accuracy0.9159 Iter10,Testing Accuracy0.9167 Iter11,Testing Accuracy0.9181 Iter12,Testing Accuracy0.9189 Iter13,Testing Accuracy0.9192 Iter14,Testing Accuracy0.9205 Iter15,Testing Accuracy0.9202 Iter16,Testing Accuracy0.921 Iter17,Testing Accuracy0.9209 Iter18,Testing Accuracy0.9213 Iter19,Testing Accuracy0.9216 Iter20,Testing Accuracy0.922
点赞 关注 收藏 ➕ 点赞 关注 收藏 ➕ 点赞 关注 收藏 ➕



