PyTorch in Practice: Pokémon Classification with a Custom Dataset, Step by Step



Preface: Preparation

The packages this walkthrough imports are listed below; install any you are missing with pip install.

Detailed installation guides for some of these libraries are in my earlier posts:
Visdom: download and pitfalls
Installing PyTorch with Anaconda

import torch
import os
import glob
import random, csv, time
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torchvision
from torchvision import transforms
from torchvision.models import resnet18
import visdom
from PIL import Image
I. Dataset overview

The custom dataset contains the following number of images per class:

Pikachu: 234, Mewtwo: 239, Squirtle: 223, Charmander: 238, Bulbasaur: 234

The photos are already stored in their corresponding class folders, as shown below.

II. Building a custom Dataset, step by step

Dataset basics: see my earlier post, "PyTorch: how to build your own Dataset for data preprocessing (with a detailed walkthrough)".

The rough skeleton of a custom Dataset looks like this; if it is unfamiliar, see the post linked above.

class Pokemon(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, item):
        pass
1. Inspecting the dataset

A quick look at the images shows three file types (jpg, png, jpeg) and widely varying sizes, so the training images will need a resize, among other transforms.
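This can be confirmed programmatically. A minimal sketch, using a throwaway temporary directory with empty stand-in files (the class and file names here are hypothetical; point the same loop at your real root folder):

```python
import os
import tempfile
from collections import Counter

# Throwaway folder mimicking the dataset layout (class/file names hypothetical).
root = tempfile.mkdtemp()
for cls, files in {'pikachu': ['0.jpg', '1.png'], 'mewtwo': ['0.jpeg']}.items():
    os.makedirs(os.path.join(root, cls))
    for f in files:
        open(os.path.join(root, cls, f), 'w').close()

# Tally file extensions across all class folders.
exts = Counter(os.path.splitext(f)[1]
               for cls in os.listdir(root)
               for f in os.listdir(os.path.join(root, cls)))
print(dict(exts))
```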

2. Building the class-to-label mapping
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        pass

root: the dataset's root directory
resize: the uniform size to which each image is resized
mode: which split to load (train, val, or test)

The model needs each label as an int, so I want __init__ to derive the name-to-label mapping automatically from the folders under root; this matches how real projects work.
A naive implementation of the mapping is simply:

dic = {}
con = 0
for name in os.listdir(root):
    dic[name] = con
    con += 1
print(dic)
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

To be robust to stray files that may appear in the root directory in practice, the real implementation builds the mapping like this (sorted names, directories only):

 self.name2label = {} # "sq...":0
 for name in sorted(os.listdir(os.path.join(root))):
     if not os.path.isdir(os.path.join(root, name)):
         continue

     self.name2label[name] = len(self.name2label.keys())
3. Loading the image list in __init__

We ultimately want an object that pairs each image path with its label, so on the first run a helper function stores that mapping in a CSV file.

glob (filename pattern matching): returns every file under a given path that satisfies a pattern, and the result can be iterated.
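A tiny self-contained demo of the pattern matching (throwaway directory, hypothetical file names):

```python
import glob
import os
import tempfile

# Throwaway directory with hypothetical file names.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, 'pikachu'))
for f in ('0.png', '1.jpg', 'notes.txt'):
    open(os.path.join(root, 'pikachu', f), 'w').close()

# '*.png' matches only the .png file; .jpg and .txt are ignored.
matches = glob.glob(os.path.join(root, 'pikachu', '*.png'))
print(matches)
```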

Writing the image-to-label pairs:

    def load_csv(self, filename):
        # filename is the CSV file name; the mapping CSV is stored under root
        # (the writing step is skipped if it already exists)
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # shuffle the order
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # each path already encodes the label; split it back out
                    # e.g. img: 'pokemon\bulbasaur\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # written as 'pokemon\bulbasaur\00000000.png', 0
                    writer.writerow([img, label])
                print('written to csv file successfully:', filename)

The resulting images.csv looks like this:

Reading the mapping back from the CSV makes loading the dataset easy; the function finally returns images and labels, which __init__ stores as instance attributes:

        # read from csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\bulbasaur\00000000.png', 0
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels
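The write-then-read cycle can be exercised in isolation. A minimal sketch with made-up rows, mirroring what load_csv does (note that csv.reader returns strings, hence the int cast):

```python
import csv
import os
import tempfile

# Made-up (path, label) rows standing in for the glob results.
rows = [('pokemon/bulbasaur/00000000.png', 0),
        ('pokemon/pikachu/00000001.jpg', 3)]

path = os.path.join(tempfile.mkdtemp(), 'images.csv')
with open(path, mode='w', newline='') as f:
    writer = csv.writer(f)
    for img, label in rows:
        writer.writerow([img, label])

# Read the pairs back; csv returns strings, so the label is cast to int.
images, labels = [], []
with open(path) as f:
    for img, label in csv.reader(f):
        images.append(img)
        labels.append(int(label))

assert len(images) == len(labels)
print(images, labels)
```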
4. Splitting the dataset

Split the dataset 60/20/20 into train, val, and test:

        if mode=='train': # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val': # 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: # 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]
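The slicing above can be sanity-checked on a toy list with plain Python:

```python
# 20 fake samples stand in for self.images.
items = list(range(20))
n = len(items)

train = items[:int(0.6 * n)]              # first 60%
val = items[int(0.6 * n):int(0.8 * n)]    # next 20%
test = items[int(0.8 * n):]               # last 20%

print(len(train), len(val), len(test))  # 12 4 4
```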

5. __getitem__

Apply the transforms to each retrieved image and convert both image and label to tensor objects:

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label
6. Checking the custom Dataset visually

Visdom is used here to visualize the batches, since it can render image tensors directly.
See the Visdom link above for download and usage notes.
First start the server from the command line: python -m visdom.server
A test function inspects the Dataset's output, with an initial batch of 32 images.
Wrapping the Dataset in a DataLoader lets us set a batch size and shuffle the data.

def test():
    viz = visdom.Visdom()
    start = time.time()
    db = Pokemon(r'D:SourceDatasetspokeman', 64, 'train')
    loader = DataLoader(db, batch_size=32, shuffle=True)
    for x, y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)
    print('time:', time.time() - start)


The images look odd at this point; that is because of the transforms we added in __getitem__.

Note: the mean and std values are the ImageNet statistics, computed over millions of images; using them is common practice.


The main culprit for the visual weirdness is the Normalize step, so we can write a method that undoes it (a denormalize):

    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean

        return x
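The round trip can be verified per channel with plain numbers (the pixel values here are arbitrary):

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

x = [0.2, 0.5, 0.9]  # one pixel, three channels (arbitrary values)

# Forward pass: what transforms.Normalize does, channel by channel.
x_hat = [(xi - m) / s for xi, m, s in zip(x, mean, std)]
# Inverse: what denormalize does.
x_back = [xh * s + m for xh, m, s in zip(x_hat, mean, std)]

print(x_back)  # recovers x up to floating-point error
```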

Then change the visualization call to viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch')) and the images display normally.

III. A quicker way to build the dataset: ImageFolder

If the dataset is laid out neatly, like the structure shown below,


then torchvision's ImageFolder can replace all of the steps above with a single line; you only need to define the transforms beforehand (here just a simple resize).
ImageFolder assumes the files are grouped into folders, one folder per class, with the folder name as the class name.

    tf = transforms.Compose([
                    transforms.Resize((64,64)),
                    transforms.ToTensor(),
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:SourceDatasetspokeman', transform=tf)

This works without a hitch.

ImageFolder also builds the class-name-to-index mapping for you; inspect it like this:

print(db.class_to_idx)

{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
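That mapping is nothing more than the sorted subfolder names, enumerated. A sketch of the idea with a throwaway directory:

```python
import os
import tempfile

# Throwaway directory with one empty folder per class.
root = tempfile.mkdtemp()
for cls in ['pikachu', 'bulbasaur', 'mewtwo', 'charmander', 'squirtle']:
    os.makedirs(os.path.join(root, cls))

# ImageFolder's mapping: sorted subfolder names, enumerated.
class_to_idx = {name: i for i, name in enumerate(sorted(os.listdir(root)))}
print(class_to_idx)
# {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
```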
IV. Training a classifier with transfer learning (custom Dataset version)

We train with resnet18, imported as follows:

from torchvision.models import resnet18
1. Importing and modifying resnet18
    trained_model = resnet18(pretrained=True)

The parameter pretrained=True loads weights that have already been pretrained.
The last layer of the ResNet must be modified by hand so that it fits our custom dataset.
PyTorch did not ship a Flatten layer at the time (recent versions provide nn.Flatten), so we write a small Flatten class ourselves; its core is a view call that reshapes the dimensions.

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
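The reshape target used by Flatten is just the product of the non-batch dimensions; checked with the standard library:

```python
import math

shape = (32, 512, 1, 1)  # [b, 512, 1, 1] coming out of the conv stack
flat = math.prod(shape[1:])  # 512 * 1 * 1
print(flat)  # 512
```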

If ResNet is unfamiliar, we can print the truncated network to decide how to modify it.
list(trained_model.children())[:-1] keeps everything except the final fully connected layer.

trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          ).to(device)

Printing the model shows its structure:

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (5): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (6): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (7): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (8): AdaptiveAvgPool2d(output_size=(1, 1))
  (9): Flatten()
)
time: 3.332292318344116

Process finished with exit code 0

Feed in a dummy tensor of a suitable shape to see what comes out:

x = torch.randn(2, 3, 64, 64).to(device)
print(model(x).shape)
torch.Size([2, 512])

So adding a single linear layer is all we need:

    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
2. A function to compute accuracy

The function is straightforward: total is the number of samples behind the loader, torch.eq(pred, y).sum() counts the correct predictions, and their ratio is the accuracy.

If basic functions such as argmax are unfamiliar, see my earlier post: a summary of common PyTorch functions and core features.

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total
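Stripped of tensors, the bookkeeping in evalute reduces to the following (the predictions here are made up):

```python
preds = [3, 1, 2, 0, 3, 4]   # per-sample argmax over the logits (made up)
labels = [3, 1, 0, 0, 2, 4]  # ground-truth labels (made up)

correct = sum(p == y for p, y in zip(preds, labels))
total = len(labels)
print(correct, total, correct / total)  # accuracy = 4/6
```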
3. The training function

The loop below is fairly generic:

    '''train'''
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0

    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            # print(f'global_step: {global_step}, loss: {loss.item()}')

            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')
            print(f'epoch: {epoch}, val_acc: {val_acc}')
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

By epoch 3 the validation accuracy already reaches 93%:

epoch: 0, val_acc: 0.4334763948497854
epoch: 1, val_acc: 0.7167381974248928
epoch: 2, val_acc: 0.8969957081545065
epoch: 3, val_acc: 0.9356223175965666
epoch: 4, val_acc: 0.9055793991416309
epoch: 5, val_acc: 0.7939914163090128
epoch: 6, val_acc: 0.8540772532188842
epoch: 7, val_acc: 0.9055793991416309
epoch: 8, val_acc: 0.9184549356223176
epoch: 9, val_acc: 0.944206008583691
best acc: 0.944206008583691 best epoch: 9
loaded from ckpt!
test acc: 0.9273504273504274
time: 458.2812337875366

Process finished with exit code 0

The loss and validation-accuracy curves over training:

V. Training a classifier with transfer learning (ImageFolder version)

The previous run used the custom Pokemon class. Since this dataset is stored so tidily, ImageFolder works just as well; only the train/val/test split has to be done by hand.
Use random_split to divide db, computing the split sizes beforehand:

    resize = 224
    tf = transforms.Compose([
        transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:SourceDatasetspokeman', transform=tf)
    train_size, val_size = int(len(db) * 0.6), int(len(db) * 0.2)
    test_size = len(db) - train_size - val_size
    train_db, val_db, test_db = torch.utils.data.random_split(dataset=db,
                                                              lengths=[train_size, val_size, test_size])
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)
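The split arithmetic above can be checked against the per-class counts from the dataset overview (234 + 239 + 223 + 238 + 234 = 1168 images in total); the remainder is assigned to test so the three sizes always sum to len(db):

```python
n = 234 + 239 + 223 + 238 + 234  # per-class counts from the dataset overview
train_size, val_size = int(n * 0.6), int(n * 0.2)
test_size = n - train_size - val_size  # remainder, so the sizes sum to n

print(n, train_size, val_size, test_size)  # 1168 700 233 235
```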
VI. Testing a single image

First store the label-to-int mapping from earlier:
img_label = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
Also remember to call model.eval() before running inference.
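Indexing list(img_label.keys()) works because Python dicts preserve insertion order, but an explicit inverse mapping is clearer. A small sketch:

```python
img_label = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

# Invert the mapping once, then look up any predicted index directly.
idx_to_name = {v: k for k, v in img_label.items()}

pred_idx = 3  # e.g. what torch.argmax(out).item() might return
print(idx_to_name[pred_idx])  # pikachu
```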

# -*- coding: utf-8 -*-
# @Time    : 2022/2/3 14:24
# @Author  : JokerTong
# @File    : test42_宝可梦测试.py
import torch
from torch import nn
from torchvision.models import resnet18
from torchvision import transforms
from PIL import Image
import visdom


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


img_label = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
                      Flatten(),
                      nn.Linear(512, 5)
                      ).to(device)
model.load_state_dict(torch.load('best.mdl'))
resize = 224
tf = transforms.Compose([
    transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
    transforms.RandomRotation(15),
    transforms.CenterCrop(resize),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
img = Image.open('test_pikaqiu.jpg')
img_tensor = tf(img)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
model.eval()
out = model(img_tensor)
predict = list(img_label.keys())[torch.argmax(out).item()]
viz = visdom.Visdom()
viz.images(transforms.ToTensor()(img), win='image', opts=dict(title='image'))
viz.text('prediction: ' + predict, win='predict', opts=dict(title='predict'))
print(predict)

Setting up a new session...
pikachu

Process finished with exit code 0

The visualized result:

Full code
# -*- coding: utf-8 -*-
# @Time    : 2022/2/1 11:36
# @Author  : JokerTong
# @File    : test41_自定义数据集.py
import torch
import os
import glob
import random, csv, time
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torchvision
from torchvision import transforms
from torchvision.models import resnet18
import visdom
from PIL import Image


class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        self.name2label = {}  # "sq...":0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue

            self.name2label[name] = len(self.name2label.keys())
        # print('name2label create success!', self.name2label)
        # image, label
        self.images, self.labels = self.load_csv('images.csv')
        # print('load images.csv success!', len(self.images))
        print(len(self.images))
        if mode == 'train':  # 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20% = 60%->80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 20% = 80%->100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        # filename is the CSV file name; the mapping CSV is stored under root
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # shuffle the order
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # each path already encodes the label; split it back out
                    # e.g. img: 'pokemon\bulbasaur\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # written as 'pokemon\bulbasaur\00000000.png', 0
                    writer.writerow([img, label])
                print('written to csv file successfully:', filename)

        # read from csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\bulbasaur\00000000.png', 0
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label

    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


def test():
    # db = Pokemon(r'D:SourceDatasetspokeman', 128, 'train')
    tf = transforms.Compose([
        transforms.Resize((128, 120)),
        transforms.ToTensor(),
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:SourceDatasetspokeman', transform=tf)
    print(db.class_to_idx)
    loader = DataLoader(db, batch_size=32, shuffle=True)
    for x, y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        # time.sleep(10)


def evalute(model, loader):
    model.eval()

    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total


if __name__ == '__main__':
    '''init'''
    start = time.time()
    viz = visdom.Visdom()
    # test()
    batchsz = 32
    lr = 1e-3
    epochs = 10
    device = torch.device('cuda')
    torch.manual_seed(1234)
    '''custom Dataset version'''
    # # build the datasets
    # train_db = Pokemon(r'D:SourceDatasetspokeman', 224, mode='train')
    # val_db = Pokemon(r'D:SourceDatasetspokeman', 224, mode='val')
    # test_db = Pokemon(r'D:SourceDatasetspokeman', 224, mode='test')
    # # create the loader objects
    # train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
    #                           num_workers=0)
    # val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    # test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)
    '''ImageFolder version'''
    resize = 224
    tf = transforms.Compose([
        transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:SourceDatasetspokeman', transform=tf)
    train_size, val_size = int(len(db) * 0.6), int(len(db) * 0.2)
    test_size = len(db) - train_size - val_size
    train_db, val_db, test_db = torch.utils.data.random_split(dataset=db,
                                                              lengths=[train_size, val_size, test_size])
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)

    # build resnet18
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224).to(device)
    # x = torch.randn(2, 3, 64, 64).to(device)
    # print(model)
    # print(model(x).shape)
    '''train'''
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0

    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            # print(f'global_step: {global_step}, loss: {loss.item()}')

            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')
            print(f'epoch: {epoch}, val_acc: {val_acc}')
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
    print('time:', time.time() - start)

Please credit the source when reposting: www.mshxw.com
Original article: https://www.mshxw.com/it/738722.html