- 参考:《动手学深度学习》(Pytorch)版 3.5 节
- 注:本文是 jupyter notebook 文档转换而来,部分代码可能无法直接复制运行!
文章目录
- 1. 获取数据集
- 2. 读取小批量
- 图像分类数据集中最常用的是手写数字识别数据集MNIST,但大部分模型在MNIST上的分类精度都超过了95%,为了更直观地观察算法之间的差异,本文介绍一个图像内容更加复杂的数据集 Fashion-MNIST,这个数据集难度比 MNIST 高,但是尺寸并不大,只有几十M,没有GPU的电脑也能吃得消
- 该数据集可以利用 torchvision 包来下载和处理,该包包含以下几个核心模块
- torchvision.datasets: 提供加载数据的函数及常用数据集接口;
- torchvision.models: 包含常用的模型结构(含预训练模型),如 AlexNet、VGG、ResNet 等;
- torchvision.transforms: 提供常用的图片变换方法,例如裁剪、旋转等;
- torchvision.utils: 提供其他的一些有用的方法
- 开始介绍前,先导入包
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import time import numpy as np from IPython import display
-
通过 torchvision.datasets.FashionMNIST 方法获取数据集
mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, transform=transforms.ToTensor())
参数说明
-
root 参数指定数据集保存路径
-
train 参数指定获取训练集还是测试集
-
download 参数若设置为 True,则在发现 root 路径下没有数据集时自动从网上下载,若已有数据集则不动作
-
transform = transforms.ToTensor() 使所有数据转换为 Tensor,如果不转换则返回的是 PIL 图片
transforms.ToTensor() 将 “尺寸为 H × W × C H times W times C H×W×C 且数据位于 [ 0 , 255 ] [0, 255] [0,255] 的PIL图片” 或者 “数据类型为 np.uint8 的NumPy数组” 转换为 “尺寸为 C × H × W C times H times W C×H×W 且数据类型为 torch.float32 且位于 [0.0, 1.0] 的Tensor”
注意 transforms.ToTensor() 在内的一些关于图片的函数默认输入为 uint8 类型,如果不是则可能得到不想要的结果,所以如果用 [ 0 , 255 ] [0,255] [0,255] 的像素值表示图片数据,则一律将其类型设置为 uint8,以免不必要的bug
-
-
这里加载的 mnist_train 和 mnist_test 都是 torch.utils.data.Dataset 的子类,一些常用方法如下
print(type(mnist_train)) print(len(mnist_train), len(mnist_test)) # 用 len() 获取该数据集的大小 feature, label = mnist_train[0] # 通过下标来访问任意样本 print(feature.shape, label) # [Channel , Height , Width] label,注意由于数据集中都是灰度图,通道数为 1 ''' torchvision.datasets.mnist.FashionMNIST 60000 10000 torch.Size([1, 28, 28]) 9 '''
-
Fashion-MNIST中一共包括了10个类别,分别为
- t-shirt(T恤)
- trouser(裤子)
- pullover(套衫)
- dress(连衣裙)
- coat(外套)
- sandal(凉鞋)
- shirt(衬衫)
- sneaker(运动鞋)
- bag(包)
- ankle boot(短靴)
使用以下函数将数值标签列表转成相应的文本标签列表
def get_fashion_mnist_labels(labels): text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels] -
使用以下函数在一行里绘制多个图像和对应的标签
def show_fashion_mnist(images, labels): display.set_matplotlib_formats('svg') # Use svg format to display plot in jupyter _, figs = plt.subplots(1, len(images), figsize=(12, 12)) for f, img, lbl in zip(figs, images, labels): f.imshow(img.view((28, 28)).numpy()) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show() -
随机显示 10 个样本
X, y = [], [] for i in np.random.randint(0,60000,size = 10).tolist(): X.append(mnist_train[i][0]) y.append(mnist_train[i][1]) show_fashion_mnist(X, get_fashion_mnist_labels(y))这里我遇到一个报错,请参考 ‘OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program’,我删除了虚拟环境中的 libiomp5md.dll 解决此问题
-
在实践中,数据读取经常是训练的性能瓶颈,torch.utils 模块提供的 DataLoader 方法允许我们方便地使用多进程来加速数据读取
-
mnist_train 是 torch.utils.data.Dataset 的子类,所以我们可以将其传入 torch.utils.data.DataLoader 来创建一个读取小批量数据样本的DataLoader 实例,在创建时
- 通过参数 num_workers 来指定读取数据的进程数量
- 通过 shuffle 参数指定读取时是否打乱
batch_size = 256 if sys.platform.startswith('win'): # 判断操作系统为 windows num_workers = 4 # 使用 4 个进程同时读取 else: num_workers = 0 # 0表示不用额外的进程来加速读取数据 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers) -
查看读取一遍数据的耗时
start = time.time() for X, y in train_iter: continue print('%.2f sec' % (time.time() - start))经测试,我的笔记本电脑在不使用多进程加速时耗时 5.88s,使用后减少到 3.18s



