Which Keras corresponds to TensorFlow 1.2.1, and which Keras corresponds to TensorFlow 2.3?

In TensorFlow 2.x, Keras ships inside TensorFlow as `tf.keras`, so no separate install is needed (TF 2.3's bundled `tf.keras` reports version 2.4.0). TensorFlow 1.2.1 predates the bundling and was typically paired with a standalone Keras release from roughly the 2.0.x line. The example below uses the bundled `tf.keras` API from TensorFlow 2.x.


# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
# import tensorflow_datasets as tfds

# Load and preprocess the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255., x_test / 255.
# Add a trailing channel dimension (Conv2D expects NHWC input)
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
# Convert the integer labels to one-hot encoding
y_train = np.float32(tf.keras.utils.to_categorical(y_train, num_classes=10))
y_test = np.float32(tf.keras.utils.to_categorical(y_test, num_classes=10))
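`tf.keras.utils.to_categorical` turns each integer label into a one-hot row vector. As a rough illustration of what it produces (the helper `to_one_hot` here is a hypothetical stand-in written for this post, not the Keras implementation):

```python
import numpy as np

def to_one_hot(labels, num_classes):
    # Row i gets a 1.0 at column labels[i], zeros elsewhere.
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out

print(to_one_hot([3, 0, 9], 10))  # three rows, each with a single 1.0
```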

# Hyperparameters
batch_size = 64
epoch = 20

# Build tf.data pipelines; shuffle the training set before batching
# (shuffling after .batch() would only reorder whole batches, not samples)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(batch_size * 10).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# Build a CNN with the Keras Sequential API
model = keras.models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', padding='SAME', input_shape=(28, 28, 1)),
    layers.Conv2D(64, (3, 3), activation='relu', padding='SAME'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu', padding='SAME'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(256, (3, 3), activation='relu', padding='SAME'),
    layers.Flatten(),   # flatten for the fully connected layers
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.2),  # drop some units to curb overfitting
    layers.Dense(10, activation='softmax')  # softmax for one-hot multi-class output
])
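The spatial dimensions shrink only at the pooling layers (with `padding='SAME'` the convolutions are size-preserving), so the size that `Flatten` sees can be checked by hand:

```python
# Trace the feature-map size through the model: 'SAME' convolutions keep
# 28x28, and each 2x2 max-pool halves both spatial dimensions.
h = w = 28               # input size
h, w = h // 2, w // 2    # after first MaxPooling2D  -> 14x14
h, w = h // 2, w // 2    # after second MaxPooling2D -> 7x7
flat_units = h * w * 256  # Flatten output: 7 * 7 * 256
print(flat_units)  # 12544
```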

# Print the model summary (summary() prints directly and returns None)
model.summary()
model.compile(optimizer=tf.optimizers.Adam(1e-3),
              loss=tf.losses.categorical_crossentropy,  # cross-entropy loss
              metrics=['accuracy'])
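Categorical cross-entropy compares the one-hot labels against the softmax probabilities. A minimal NumPy sketch of the formula (illustrative only, not the `tf.losses` implementation):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean over samples of -sum_k y_true[k] * log(y_pred[k]);
    # clipping avoids log(0).
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

y_true = np.array([[0.0, 1.0, 0.0]])  # one-hot label for class 1
y_pred = np.array([[0.1, 0.8, 0.1]])  # softmax output
print(round(categorical_crossentropy(y_true, y_pred), 4))  # -log(0.8) ~ 0.2231
```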

# Train the model, validating on the test set after each epoch
history = model.fit(train_dataset, epochs=epoch, validation_data=test_dataset)

# Visualize training: plot the accuracy and loss curves
plt.figure(1, figsize=(10, 8))
plt.suptitle('the figure of ACCURACY and LOSS')
plt.subplot(2, 1, 1)
# plt.title('ACCURACY')
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.xlim([0, epoch])
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.ylabel('Accuracy')
plt.ylim([0.95, 1])
plt.legend(loc='lower right')

plt.subplot(2, 1, 2)
# plt.title('LOSS')
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.xlim([0, epoch])
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.ylabel('Loss')
plt.ylim([0, 0.15])
plt.legend(loc='upper right')

plt.savefig("result")
plt.show()

# Print the final test metrics
test_loss, test_acc = model.evaluate(test_dataset, verbose=2)
print('\nTest accuracy:', test_acc)

Final result: the test accuracy printed above, plus the accuracy/loss curves saved to "result".

Reprinted from www.mshxw.com; original article: https://www.mshxw.com/it/783030.html