栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch卷积快速分类框架

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch卷积快速分类框架

文章目录
  • 方便新手分类的一个简单框架
    • Pytorch快速分类微型框架
      • 介绍
        • 框架介绍
        • 环境
        • 需要安装的第三方包
      • 使用说明
      • (可选)在自己的环境上安装此包
    • 例子1:用于LeNet5图像分类
    • 例子2:珠宝分类

方便新手分类的一个简单框架 Pytorch快速分类微型框架

链接:Gitee码云

介绍

基于Pytorch开发的中文快速分类框架,易于修改。

框架用途: 图像分类、模型训练

框架介绍

本项目基于 PytorchModelTools 项目二次抽象重构。

  • 之前的代码太过于糟糕(可移植性差),虽然能用,但是可移植性差,代码耦合高。
环境

理论上什么环境都能使用,因为我的代码并没有用到什么太复杂的操作。

Pytorch 1.8.1
cuda 11.1
Python 3.6
需要安装的第三方包
pip install xlwt tqdm matplotlib numpy pandas
使用说明

框架功能已经和原框架一致,可通过demo/lenet5_cifar10.py和demo/resnet_gemstones.py快速了解。

(可选)在自己的环境上安装此包
python setup.py install
例子1:用于LeNet5图像分类

例子来源于demo/lenet5_cifar10.py,可下载后直接跑通

import torch.nn as nn
from torchvision.datasets import CIFAR10
import torch.optim as optim
from fcf import *
import torchvision.transforms as transforms
from torch.nn import functional as F
from torch.utils.data import DataLoader

frame = baseframe().init_logging().set_seed()
# 环境构建
model = LeNet5()
optimizer = optim.SGD(params=model.parameters(),lr=0.001,momentum=0.9)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
loss_fc = nn.CrossEntropyLoss()
# 构建模型进行训练
frame.build(model, is_cuda=False, device='cuda:0')  # 如果is_cuda是True,那么device设置的才有意义。
frame.config(
    epochs=1000,  # 循环的epoch
    optimizer=optimizer,  # 优化器
    scheduler=scheduler,  # 学习率衰减
    loss_fc=loss_fc,  # 损失函数
    train_dataloader=trainloader,  # 数据集加载
    test_dataloader=testloader,
    is_softmax=False,  # softmax层,默认不需要, 在框架中是特殊用途,自己需要softmax自己加,记得加了softmax不能用交叉熵损失。
    is_half=False,  # 开启半精度训练
    is_two_category=False,  # 是否是二分类,如果是,则会有二分类指标。例如:Recall、特异性、ROC等。
)
frame.train(is_checkpoint=False)  # 是否要保持模型
例子2:珠宝分类

例子来源于:demo/resnet_gemstones.py

import logging

import torch.nn as nn
from torchvision.models import resnet34
import torch
import warnings
import torch.optim as optim

warnings.filterwarnings("ignore")
from dataset.gemstones_dataset import train_dataloader, test_dataloader
from fcf import *

# 环境准备


if __name__ == '__main__':
    # 模型构建
    frame = baseframe().init_logging().set_seed()
    # 环境构建
    model = resnet34(pretrained=True)
    model.fc = nn.Sequential(
        nn.Dropout(),
        nn.Linear(512, 87, bias=False),
    )
    optimizer = optim.RMSprop(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    loss_fc = nn.CrossEntropyLoss()
    # 构建模型进行训练
    frame.build(model, is_cuda=False, device='cuda:0')  # 如果is_cuda是True,那么device设置的才有意义。
    frame.config(
        epochs=1000,  # 循环的epoch
        optimizer=optimizer,  # 优化器
        scheduler=scheduler,  # 学习率衰减
        loss_fc=loss_fc,  # 损失函数
        train_dataloader=train_dataloader,  # 数据集加载
        test_dataloader=test_dataloader,
        is_softmax=False,  # softmax层,默认不需要, 在框架中是特殊用途,自己需要softmax自己加,记得加了softmax不能用交叉熵损失。
        is_half=False,  # 开启半精度训练
        is_two_category=False,  # 是否是二分类,如果是,则会有二分类指标。例如:Recall、特异性、ROC等。
    )
    frame.train(is_checkpoint=False)  # 是否要保持模型
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/300341.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号