栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

卷积神经网络训练过程_卷积神经网络训练算法?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

卷积神经网络训练过程_卷积神经网络训练算法?

目录

Cifar10数据集介绍

​ 卷积神经网络的搭建

 完整代码


Cifar10数据集介绍

Cifar10数据提供了5万张32*32像素点的十分类彩色图片和标签,用于训练;提供了1万张32*32像素点的十分类彩色图像和标签用于测试。

导入cifar10数据集:

cifar10 = tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()

要想可视化出样本可以如下操作:

也可以打印出一个样本的特征:

 第一个样本的特征发现是32行32列的三通道的三维数组

 也可以打印出训练集的第一张样本标签:

也可以打印出测试集的形状:

 卷积神经网络的搭建

搭建一个一层卷积两层全连接的网络:

 随着网络的逐渐复杂,我们可以使用class类搭建网络。

 完整代码
import numpy as np
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import PySide2
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPooling2D,Dropout,Flatten,Dense
from tensorflow.keras import Model


dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path
np.set_printoptions(threshold=np.inf)  # 设置打印出所有参数,不要省略

mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 由于fashion数据集是三维的(60000, 28, 28),而cifar10 数据集是四维的,而此网络是用来识别四维的数据所所以需要将3维的输入扩展维4维的输入
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)


class baseline(Model):
    def __init__(self):
        super(baseline, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5,5), padding='same')
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPooling2D(pool_size=(2,2), strides=2, padding='same')
        self.d1 = Dropout(0.2)

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y

model = baseline()

model.compile(optimizer='adam',
              loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------------load the model---------------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=10, epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + 'n')
    file.write(str(v.shape) + 'n')
    file.write(str(v.numpy()) + 'n')
file.close()

############################  show  #############################
# 显示训练集和验证集的acc和loss曲线
acc=history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/786704.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号