栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch入门:LeNet手写字体识别案例

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch入门:LeNet手写字体识别案例

# 1 加载必要的库
# 2 定义超参数
# 3 构建pipeline(transforms),对图像进行处理
# 4 下载,加载数据集(MNIST)
# 5 创建网络模型
# 6 定义优化器
# 7 定义训练方法
# 8 定义测试方法
# 9 调用训练,测试方法,并且输出结果

# 1 加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# 2 定义超参数
BATCH_SIZE = 64
DEVICE = torch.device("cpu")
EPOCHS = 10

# 3 构建pipeline(transforms),对图像进行处理
pipeline = transforms.Compose([
    transforms.ToTensor(),                        #变为Tensor类型
    transforms.Normalize((0.1307,), (0.3081,))    #归一化
])


# 4 下载,加载数据集(MNIST)
from torch.utils.data import DataLoader

train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)

train_loader =  (train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)


# 5 创建网络模型
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(input_size, -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


# 6 定义优化器

model = Digit().to(DEVICE)

optimizer = optim.Adam(model.parameters())


# 7 定义训练方法
def train_model(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_index % 3000 == 0:
            print("epoch:{}t loss:{:.6f}".format(epoch, loss.item()))
         #  print(len(train_loader.dataset))
         #  print(batch_index)


# 8 定义测试方法
def test_model(model, device, test_loader):
    model.eval()
    correct = 0.0
    test_loss = 0.0
    with torch.no_grad():#不需要计算梯度
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()
            # pred = output.max(1,keepdim=True)[1]
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)

        #print(len(test_loader.dataset))
        print("Test Average loss : {:.4f},Accuracy : {:.3f}n".format(
            test_loss, 100.0 * correct / len(test_loader.dataset)))


# 9 调用训练,测试方法,并且输出结果
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)

---------------------------------------------------------------------------------------------------

将以上代码复制到PyChorm中,便可以直接使用CPU进行运行。

PS:最近昇腾CANN训练营正在进行中,这次训练营包含了模型营、算子营和应用营。基本上包含了华为昇腾AI全栈全流程的软硬件知识,欢迎感兴趣的小伙伴报名参加!

报名地址:昇腾CANN训练营第三期_开发者-华为云

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/308700.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号