栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

在MMClassification中使用Swin-Transformer开始一个分类任务

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

在MMClassification中使用Swin-Transformer开始一个分类任务

最近, Swin Transformer成为ICCV2021的 best paper。作为基础模型,其在分类、检测与分割等下游任务上都取得了SOTA的结果。MMClassification(MMCls)是一个开源图像分类工具箱,是OpenMMLab开源算法库的成员之一。本文主要介绍在MMCls中使用Swin-Transformer开始一个分类任务,具体代码下载:谷歌网盘。 

相关教程文档:欢迎来到 MMClassification 中文教程! — MMClassification 0.16.0 文档https://mmclassification.readthedocs.io/zh_CN/latest/

目录

1. MMClassification安装

2. 数据集准备

3. 使用MMCls做模型微调

准备修改配置文件

训练

测试模型

可视化结果


 

1. MMClassification安装

在使用 MMClassification 之前,我们需要配置环境,步骤如下:

  • 安装 Python, CUDA, C/C++ compiler 和 git
  • 安装 PyTorch (CUDA 版)
  • 安装 MMCV
  • 克隆 MMCls github 代码库然后安装

安装python、cuda、torch等可以参考链接以及网上的教程。安装完后:

检查 nvcc 版本

nvcc -V

检查gcc版本

gcc --version

检查torch版本

pip list | grep "torch"

安装MMCV :

MMCV 是 OpenMMLab 代码库的基础库。Linux 环境的安装 whl 包已经提前打包好,大家可以直接使用pip下载安装,格式如下:

pip install mmcv -f https://download.openmmlab.com/mmcv/dist/{CUDA_v}/{Torch_v}/index.html

需要注意 PyTorch 和 CUDA 版本,确保能够正常安装。

在前面的步骤中,我们输出了环境中 CUDA 和 PyTorch 的版本,分别是 11.1 和 1.9.0,我们需要选择相应的 MMCV 版本。

另外,也可以安装完整版的 MMCV-full,它包含所有的特性以及丰富的开箱即用的 CUDA 算子。完整版本可能需要更长时间来编译。

pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
# pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html

安装MMCls:

python setup.py install:   安装后包比较稳定,如果需要修改代码,则需要在修改后重新安装才能生效。
python setup.py develop:  安装后包需要不断修改,不需要重新安装,修改的代码就能生效。

git clone https://github.com/open-mmlab/mmclassification
cd mmclassification
python setup.py develop    # 以开发者模式安装
# python setup.py install  # 以普通模式安装

2. 数据集准备

猫狗分类数据集

这里使用猫狗分类数据集作为例子

# 下载分类数据集文件,在目录 $mmclassification 下。
wget https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0 -O cats_dogs_dataset.zip
mkdir data
unzip -q cats_dogs_dataset.zip -d ./data/

完成下载和解压之后, "Cats and Dogs Dataset" 文件夹下的文件结构如下:

data/cats_dogs_dataset
├── classes.txt
├── test.txt
├── val.txt
├── training_set
│   ├── training_set
│   │   ├── cats
│   │   │   ├── cat.1.jpg
│   │   │   ├── cat.2.jpg
│   │   │   ├── ...
│   │   ├── dogs
│   │   │   ├── dog.2.jpg
│   │   │   ├── dog.3.jpg
│   │   │   ├── ...
├── val_set
│   ├── val_set
│   │   ├── cats
│   │   │   ├── cat.3.jpg
│   │   │   ├── cat.5.jpg
│   │   │   ├── ...
│   │   ├── dogs
│   │   │   ├── dog.1.jpg
│   │   │   ├── dog.6.jpg
│   │   │   ├── ...
├── test_set
│   ├── test_set
│   │   ├── cats
│   │   │   ├── cat.4001.jpg
│   │   │   ├── cat.4002.jpg
│   │   │   ├── ...
│   │   ├── dogs
│   │   │   ├── dog.4001.jpg
│   │   │   ├── dog.4002.jpg
│   │   │   ├── ...

可以通过 shell 命令 `tree data/cats_dogs_dataset` 查看文件结构。

支持新的数据集

MMClassification 要求数据集必须将图像和标签放在同级目录下。有两种方式可以支持自定义数据集。

最简单的方式就是将数据集转换成现有的数据集格式(比如 ImageNet)。另一种方式就是新建一个新的数据集类。细节可以查看 文档.

在这个教程中,为了方便学习,我们已经将 “猫狗分类数据集” 按照 ImageNet 的数据集格式进行了整理。

标准文件包括:

1. 类别列表。每行代表一个类别。第一行 cats 类别标注为 0, 第二行 dogs 类别标注为 1.

cats
dogs


2. 训练/验证/测试标签。
每行包括一个文件名和其相对应的标签。 

    ...
    cats/cat.3769.jpg 0
    cats/cat.882.jpg 0
    ...
    dogs/dog.3881.jpg 1
    dogs/dog.3377.jpg 1
    ...

3. 使用MMCls做模型微调

通过命令行进行模型微调步骤如下:

1. 准备自定义数据集
2. 数据集适配 MMCls 要求
3. 在 py 脚本中修改配置文件
4. 使用命令行工具进行模型微调

第1,2步与之前的介绍一致,我们将会介绍后面2个步骤的内容。

准备修改配置文件

为了能够复用不同配置文件中常用的部分,我们支持多配置文件继承。比如模型微调 swin-transfomer-tiny ,新的配置文件可以通过继承 “configs/_base_/models/swin_transformer/tiny_224.py” 来创建模型的基本结构。 继承 “configs/_base_/datasets/imagenet_bs64_swin_224.py” 来使用之前定义好的数据集。继承 “configs/_base_/schedules/imagenet_bs1024_adamw_swin.py” 来自定义学习率策略。为了能够运行设定的学习率策略,还需要继承  “configs/_base_/default_runtime.py”.

配置文件开头应该显示如下

_base_ = [
    '../_base_/models/swin_transformer/tiny_224.py', '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py','../_base_/default_runtime.py'
]

第一,修改模型配置。这个新的配置文件需要根据分类问题的类别来调整模型 head 的 num_classes。预训练模型的权重,除了最后一层线性层,其他的部分一般选择复用。

model = dict(
    backbone=dict(
        init_cfg = dict(
            type='Pretrained', 
            checkpoint="https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth", 
            prefix='backbone')
    ),
    head=dict(
        num_classes=2,
        topk = (1, )
    ),
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=2, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=2, prob=0.5)
    ])
    )

第二是数据配置。注意根据自己GPU的现存大小来调节samples_per_gpu,指定数据集的路径,每一个epoch评估一次

img_norm_cfg = dict(
     mean=[124.508, 116.050, 106.438],
     std=[58.577, 57.310, 57.437],
     to_rgb=True)

data = dict(
    # 每个 gpu 上的 batch size 和 num_workers 设置,根据计算机情况设置
    samples_per_gpu = 32,
    workers_per_gpu=2,
    # 指定训练集路径
    train = dict(
        data_prefix = 'data/cats_dogs_dataset/training_set/training_set',
        classes = 'data/cats_dogs_dataset/classes.txt'
    ),
    # 指定验证集路径
    val = dict(
        data_prefix = 'data/cats_dogs_dataset/val_set/val_set',
        ann_file = 'data/cats_dogs_dataset/val.txt',
        classes = 'data/cats_dogs_dataset/classes.txt'
    ),
    # 指定测试集路径
    test = dict(
        data_prefix = 'data/cats_dogs_dataset/test_set/test_set',
        ann_file = 'data/cats_dogs_dataset/test.txt',
        classes = 'data/cats_dogs_dataset/classes.txt'
    )
)
# 修改评估指标设置
evaluation = dict(interval=1, metric='accuracy', metric_options={'topk': (1, )})

第三是学习率策略。模型微调的策略与默认策略差别很大。微调一般会要求更小的学习率和更少的训练周期。

optimizer = dict(lr=0.00025)

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr_ratio=1e-1,
    warmup='linear',
    warmup_ratio=1e-1,
    warmup_iters=100,
    warmup_by_epoch=False)

runner = dict(max_epochs=2)

最后,运行环境配置。直接使用默认的配置。

将上述继承以及修改内容,保存进“configs/swin_transformer/swin-tiny_cats-dogs.py”中

_base_ = [
    '../_base_/models/swin_transformer/tiny_224.py', '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py','../_base_/default_runtime.py'
]

model = dict(....)  # 上述代码框的内容复制过来

img_norm_cfg = dict(...)
data = dict(...)
evaluation = dict(...)

optimizer = dict(...)
lr_config = dict(...)
runner  = dict(...)

查看完整的配置文件信息:

python ./tools/misc/print_config.py ./configs/swin_transformer/swin-tiny_cats-dogs.py

训练

我们使用 tools/train.py 进行模型微调:

python tools/train.py ${CONFIG_FILE} [optional arguments]

如果你希望指定训练过程中相关文件的保存位置,可以增加一个参数 --work_dir ${YOUR_WORK_DIR}.

通过增加参数 --seed ${SEED},设置随机种子以保证结果的可重复性,而参数 --deterministic 则会启用 cudnn 的确定性选项,进一步保证可重复性,但可能降低些许效率。

本文例子训练代码如下:

python tools/train.py 
  configs/swin_transformer/swin-tiny_cats-dogs.py 
  --work-dir work_dirs/swin-tiny_cats-dogs 
  --seed 0 
  --deterministic

测试模型

使用 tools/test.py 对模型进行测试:

python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments] [--out ${RESULT_FILE}]

这里有一些可选参数可以进行配置:

--metrics: 评价方式,这依赖于数据集,比如准确率 acc

--metric-options: 对于评估过程的自定义操作,如 topk=1.

--out: 输出结果的文件名。如果不指定,计算结果不会被保存。支持的格式包括json, pkl 和 yml

本文例子测试代码如下:

python tools/test.py ./configs/swin_transformer/swin-tiny_cats-dogs.py work_dirs/swin-tiny_cats-dogs/latest.pth --metrics=accuracy --metric-options=topk=1

可视化结果

我们用下面的命令进行推理单张图片并可视化计算结果。

python demo/image_demo.py ${Image_Path} ${Config_Path} ${Checkpoint_Path} --device {cuda or cpu}

本文例子代码如下: 

python demo/image_demo.py ./data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg ./configs/swin_transformer/swin-tiny_cats-dogs.py work_dirs/swin-tiny_cats-dogs.py/latest.pth

相关代码以及运行过程可以参考谷歌网盘: https://drive.google.com/file/d/1Z41vYvJkWbMAli81ppUmdIGPMV7bBPY2/view?usp=sharing

相关链接:

作者学术讲座: 胡瀚研究员:Swin Transformer和拥抱Transformer的五个理由 | 自动化所系列学术讲座_哔哩哔哩_bilibili本次报告将介绍一种新的视觉骨干网络Swin Transformer,相比于谷歌主要为图像分类问题设计的ViT网络,Swin Transformer对于各种视觉任务都广泛有效,包括图像分类、检测和分割等等。本次报告还将梳理4年来视觉领域逐渐挖掘Transformer优点的发展脉络,并展开讲述拥抱Transformer的5个理由,希望通过这个报告让听众对于Transformer在视觉中的应用有一个整体https://www.bilibili.com/video/BV1eb4y1k7fj?p=1&share_medium=iphone&share_plat=ios&share_session_id=F9A81F20-92D2-4243-9695-BF8935AD81F9&share_source=COPY&share_tag=s_i×tamp=1634388085&unique_k=HUIs9U

论文地址: https://arxiv.org/abs/2103.14030https://arxiv.org/abs/2103.14030

MMClassification : GitHub - open-mmlab/mmclassification: OpenMMLab Image Classification Toolbox and Benchmarkhttps://github.com/open-mmlab/mmclassification

MMClassification 文档:欢迎来到 MMClassification 中文教程! — MMClassification 0.16.0 文档https://mmclassification.readthedocs.io/zh_CN/latest/

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/342204.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号