栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

手写字体识别(3) 训练及测试

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

手写字体识别(3) 训练及测试

目录
  • 训练
    • 训练代码
    • nn.NLLLoss()与nn.CrossEntropyLoss()的区别
  • 测试
    • 混淆矩阵(Confusion Matrix)
    • 验证代码

github地址:
https://github.com/Huyf9/mnist_pytorch/

训练

在训练之前,需要定义以下几个参数

DEVICE  #设备
BATCH_SIZE  #批量数
CRITERION  #损失函数
LR  #学习率
OPTIMIZER  #优化器
EPOCHS  #训练轮数

我的训练参数设置如下:

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 64
CRITERION  = nn.NLLLoss()
LR = 0.001
OPTIMIZER  = torch.optim.SGD(net.parameters(), lr=LR)
EPOCHS = 200
训练代码

以卷积网络模型为例,训练代码如下:

import torch
import torch.nn as nn
from MnistDataset import Mydataset
from torch.utils.data import DataLoader
from model import ConvNet
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(64)  # 设置一个随机种子,保证每次训练的结果一样

train_path = ['train.txt', 'tr_label.txt']
val_path = ['val.txt', 'val_label.txt']
train_dataset = Mydataset(train_path[0], train_path[1], device)
val_dataset = Mydataset(val_path[0], val_path[1], device)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=4)

net = ConvNet().to(device)
CRITERION = nn.NLLLoss()
LR = 0.001
OPTIMIZER = torch.optim.SGD(net.parameters(), lr=LR)
EPOCHS = 100

def train(epoch, epoch_loss, batch_num):

    for i, (pic, label) in tqdm(enumerate(train_dataloader)):

        batch_num += 1

        net.zero_grad()
        out = net(pic)
        # print(out)
        loss_value = loss(out, label)
        epoch_loss += loss_value

        loss_value.backward()
        optimizer.step()

    print(f'epoch: {epoch}ttrain_loss: {epoch_loss/batch_num}')

def val(epoch, epoch_loss, batch_num):
    for i, (pic, label) in tqdm(enumerate(val_dataloader)):
        epoch_num += 1
        out = net(pic)
        loss_value = loss(out, label)
        epoch_loss += loss_value

    print(f'epoch: {epoch}tval_loss: {epoch_loss/batch_num}')


def main():
    for epoch in range(EPOCHS):
        epoch_loss, batch_num = 0, 0
        train(epoch, epoch_loss, batch_num)
        val(epoch, epoch_loss, batch_num)

        if (epoch+1) % 10 == 0:
            torch.save(net.state_dict(), f'model_parameter\parameter_epo{epoch}.pth')


if __name__ == '__main__':
    main()

这里我利用epoch_loss来累加每一个批次的损失函数,再用batch_num来记录每一轮的批次数,最后相除作为这一轮的平均损失。

nn.NLLLoss()与nn.CrossEntropyLoss()的区别

C r o s s E n t r o p y L o s s ( ) = N L L L o s s ( ) + L o g S o f t m a x ( ) CrossEntropyLoss() = NLLLoss() + LogSoftmax() CrossEntropyLoss()=NLLLoss()+LogSoftmax()
由于我们在构建网络的时候在最后一层加上了nn.LogSoftmax(),因此在定义损失函数时我们采用NLLLoss()。

测试

验证部分我们利用训练好的训练模型,将图片输入进去,返回一个1行10列的向量表示数字0-9的概率,我们取最大概率的索引表示模型判断的数字。我们将其保存在混淆矩阵中

混淆矩阵(Confusion Matrix)

混淆矩阵表示分类模型的预测值与真实值的对比情况。以二分类的混淆矩阵为例:

PositiveNegative
TrueTPTN
FalseFPFN

TP表示正确被预测为正例的数量。
TN表示正确被预测为负例的数量。
FP表示错误预测为正例的数量。
FN表示错误被预测为负例的数量。

这里我们引入三个评价指标来度量一个模型的预测能力:查准率(Percision)、查全率(Recall)、F1。
P e r c i s i o n = T P / ( T P + F P ) Percision = TP / (TP+FP) Percision=TP/(TP+FP)

表示分类模型预测为正例的样本中,真正为正例的样本比重。

R e c a l l = T P / ( T P + F N ) Recall = TP / (TP+FN) Recall=TP/(TP+FN)

表示分类模型预测为正例的样本占总正例样本的比重

一般来说,Percision与Recall为一对矛盾的指标,我们一般不会指定某一个指标衡量模型性能,因此我们需要协调两种指标的值,这样就需要引入一个评价指标F1:
F 1 = 2 ∗ P e r c i s i o n ∗ R e c a l l P e r c i s i o n + R e c a l l F1 = 2*{{Percision*Recall}over{Percision+Recall}} F1=2∗Percision+RecallPercision∗Recall​

验证代码

验证代码如下

import torch
import seaborn as sn
from matplotlib import pyplot as plt
from model import ConvNet
from MnistDataset import Mydataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd


torch.manual_seed(13)

def get_score(confusion_mat):
    smooth = 0.0001  #防止出现除数为0而加上一个很小的数
    tp = np.diagonal(confusion_mat)
    fp = np.sum(confusion_mat, axis=0)
    fn = np.sum(confusion_mat, axis=1)
    precision = tp / (fp + smooth)
    recall = tp / (fn + smooth)
    f1 = 2 * precision * recall / (precision + recall + smooth)
    return precision, recall, f1

def get_confusion(confusion_matrix, out, label):
    idx = np.argmax(out.detach().numpy())
    confusion_matrix[idx, label] += 1
    return confusion_matrix

def main():
    confusion_matrix = np.zeros((10, 10))

    net = ConvNet()
    net.load_state_dict(torch.load('model_parameter\parameter_epo90.pth'))
    test_path = ['test.txt', r'dataset/test_label.txt']
    test_dataset = Mydataset(test_path[0], test_path[1], 'cpu')
    test_dataloader = DataLoader(test_dataset, 1, True)
    for i, (pic, label) in enumerate(test_dataloader):
        out = net(pic)
        confusion_matrix = get_confusion(confusion_matrix, out, label)

    precision, recall, f1 = get_score(confusion_matrix)
    print(f'precision: {np.average(precision)}trecall: {np.average(recall)}tf1: {np.average(f1)}')
    confusion_mat = pd.DataFrame(confusion_matrix)
    confusion_df = pd.DataFrame(confusion_mat, index=[i for i in range(10)], columns=[i for i in range(10)])
    sn.heatmap(data=confusion_df, cmap='RdBu_r')
    plt.show()
    confusion_df.to_csv(r'confusion.csv', encoding='ANSI')


if __name__ == '__main__':
    main()
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/1018219.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号