栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch框架下的Mnist手写数字识别demo

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch框架下的Mnist手写数字识别demo

pytorch框架下的Mnist手写数字识别demo

数据集下载地址:http://deeplearning.net/data/mnist/

# -- 构造网络 --
# mnist数据集50000个样本,每个样本28*28=784个像素点
# 输入数据1*784,设计三层网络,第一层784*128,第二层128*256,第三层256*10
from torch import nn
import torch.nn.functional as F


class Mnist_NN(nn.Module):  # 继承父类
    def __init__(self):
        super().__init__()  # 继承父类的构造函数
        self.hidden1 = nn.Linear(784, 128)  # 隐层1 784*128
        self.hidden2 = nn.Linear(128, 256)  # 隐层2 128*256
        self.out = nn.Linear(256, 10)  # 输出层

    def forward(self, x):  # 前向传播
        x = F.relu(self.hidden1(x))  # 输入-隐层1,激活函数relu
        x = F.relu(self.hidden2(x))  # 隐层1-隐层2,激活函数relu
        x = self.out(x)  # 隐层2-输出
        return x  # 返回前向传播计算结果


# -- 构造数据集 --
import gzip
import pickle
import torch
from torch.utils.data import TensorDataset  # 用于创建数据集
from torch.utils.data import DataLoader  # 用于加载数据集
# 使用TensorDataset和DataLoader创建的数据集可以根据传入的batch自动抽样
# 也可自动在每次分组时洗牌

with gzip.open('data/mnist/mnist.pkl.gz', 'rb') as f:  # 读取数据
    ((x_train, y_train), (x_test, y_test), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_test, y_test = map(  # 将数据类型转换为tensor
    torch.tensor, (x_train, y_train, x_test, y_test)
)

train_dataset = TensorDataset(x_train, y_train)  # 训练数据集
test_dataset = TensorDataset(x_test, y_test)  # 测试数据集


def getData(train_dataset, test_dataset, batch_size):  # 加载数据集
    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(test_dataset, batch_size=batch_size * 2),
    )


# -- 训练 --
import numpy as np
from torch import optim

loss_func = F.cross_entropy  # 损失函数,直接从functional中调用交叉熵函数


def getModel():  # 获取实例化模型和优化器
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)


def loss_batch(model, loss_func, x_bath, y_bath, opt=None):
    loss = loss_func(model(x_bath), y_bath)
    # 有优化器,即训练,需要进行更新参数等操作
    # 无优化器,即测试,只求损失值即可
    if opt is not None:
        loss.backward()  # 反向传播
        opt.step()  # 更新参数
        opt.zero_grad()  # 梯度清零

    return loss.item(), len(x_bath)


def mnist(steps, model, loss_func, opt, train_data, test_data):
    # steps 迭代多少次
    # model 网络的实例
    # loss_func 损失函数
    # opt 优化器
    # train_data 训练数据
    # test_data 测试数据
    for step in range(steps):
        for x_bath, y_bath in train_data:  # 训练
            loss_batch(model, loss_func, x_bath, y_bath, opt)

        with torch.no_grad():  # 测试,不更新参数
            losses, nums = zip(
                *[loss_batch(model, loss_func, x_bath, y_bath) for x_bath, y_bath in test_data]
            )
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
            print('当前step:' + str(step), '验证集损失:' + str(val_loss))


训练:

train_data, test_data = getData(train_dataset, test_dataset, 64)
model, opt = getModel()
mnist(25, model, loss_func, opt, train_data, test_data)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/739629.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号