1. Dataset
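Provides a way to fetch each data sample and its label. Concretely, a Dataset has to answer two questions:

- How do we get each sample and its label?
- How many samples are there in total? This tells us how many iterations one pass over the data takes.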
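The two questions above map onto two methods: `__getitem__` fetches sample `i` together with its label, and `__len__` reports the total count.

Below is a minimal sketch of a custom `Dataset`. It assumes the samples already sit in memory as Python lists. `SquaresDataset` and its data are hypothetical and exist only to illustrate the pattern.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i squared."""

    def __init__(self, n):
        # Keep data and labels in memory; a real dataset might read files here instead.
        self.data = list(range(n))
        self.labels = [x * x for x in self.data]

    def __getitem__(self, index):
        # Return one (sample, label) pair for the given index.
        return self.data[index], self.labels[index]

    def __len__(self):
        # Total number of samples; DataLoader uses this to know how many batches to build.
        return len(self.data)


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```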
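```python
from torch.utils.data import Dataset

help(Dataset)  # print the class documentation
# Dataset??    # in Jupyter/IPython, `??` shows the source code instead
```

Key parts of the source, excerpted from `torch/utils/data/dataset.py`. Subclasses must override `__getitem__`, and they usually override `__len__` as well:

```python
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
```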
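Because `__add__` returns a `ConcatDataset`, two datasets can be joined with `+`. In the sketch below, `TensorDataset` stands in for real data, and the tensors are made up for illustration.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two toy datasets of (sample, label) pairs.
ds_a = TensorDataset(torch.arange(3), torch.zeros(3, dtype=torch.long))
ds_b = TensorDataset(torch.arange(10, 14), torch.ones(4, dtype=torch.long))

# Dataset.__add__ wraps both in a ConcatDataset.
combined = ds_a + ds_b
print(type(combined).__name__)  # ConcatDataset
print(len(combined))            # 7

# Index 3 is past the end of ds_a (length 3), so it maps to ds_b[0].
x, y = combined[3]
print(x.item(), y.item())       # 10 1
```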
2. DataLoader (data loader)
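Takes the samples from a Dataset and packs them into batches. This supplies the network downstream with data in the form it needs: the batch size, whether to shuffle, how many worker processes load the data, and so on.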
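Here is a small sketch of what "packing into batches" means. It uses a hypothetical stand-in dataset of 10 random 3x32x32 "images" with labels 0 to 9. The default collate function stacks the images and the labels separately, adding a leading batch dimension to each.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for an image dataset: 10 fake 3x32x32 images, labels 0..9.
images = torch.rand(10, 3, 32, 32)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# shuffle=False keeps the order predictable for this demo.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch_shapes = []
for imgs, targets in loader:
    # Each batch is an (images, labels) pair of stacked tensors.
    batch_shapes.append((tuple(imgs.shape), tuple(targets.shape)))

print(batch_shapes)
# [((4, 3, 32, 32), (4,)), ((4, 3, 32, 32), (4,)), ((2, 3, 32, 32), (2,))]
```

The last batch holds only 2 samples because `drop_last` defaults to `False`. The CIFAR-10 code below sets `drop_last=True` to discard such a partial batch.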
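2.1 Code implementation

```python
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# The __main__ guard is required when num_workers > 0 on platforms that
# spawn worker processes (Windows, macOS).
if __name__ == "__main__":
    writer = SummaryWriter("logs")

    # download=True fetches CIFAR-10 into ./dataset if it is not there yet
    test_dataset = torchvision.datasets.CIFAR10("./dataset", train=False,
                                                transform=transforms.ToTensor(),
                                                download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=64,
                             shuffle=True, num_workers=4, drop_last=True)

    # First image of the test set and its target
    img, target = test_dataset[0]
    print(img.shape)                     # torch.Size([3, 32, 32])
    print(test_dataset.classes[target])  # class name for that label

    # The loader packs imgs and targets into separate batch tensors.
    # Run 10 epochs; shuffle=True gives a new order in each epoch.
    for epoch in range(10):
        step = 0
        for data in test_loader:
            imgs, targets = data  # imgs: [64, 3, 32, 32], targets: [64]
            # add_images logs a whole NCHW batch as one image grid per step
            writer.add_images("Epoch:{}".format(epoch), imgs, step,
                              dataformats="NCHW")
            step = step + 1
            # print(imgs.shape)
            # print(targets)

    writer.close()
```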
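The CIFAR-10 test split has 10,000 images. With `batch_size=64`, the loop above therefore runs `10000 // 64 = 156` steps per epoch when `drop_last=True`, because the last 16 images are discarded. Without `drop_last`, it runs 157 steps.

The quick check below assumes a `TensorDataset` of 10,000 indices is an adequate stand-in for the images, so no download is needed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the 10,000-image CIFAR-10 test set: just the indices 0..9999.
indices = TensorDataset(torch.arange(10000))

loader_drop = DataLoader(indices, batch_size=64, shuffle=True, drop_last=True)
loader_keep = DataLoader(indices, batch_size=64, shuffle=True, drop_last=False)

print(len(loader_drop))  # 156 full batches; the remaining 16 samples are dropped
print(len(loader_keep))  # 157; the last batch holds only 16 samples

# With shuffle=True each epoch is a fresh permutation: the order changes,
# but every sample still appears exactly once when drop_last=False.
seen = torch.cat([batch for (batch,) in loader_keep])
print(torch.equal(seen.sort().values, torch.arange(10000)))  # True
```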



