- 一、datasets
- 二、DataLoader
- 补充:datasets类的代码
本文为学习笔记,感谢PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
一、datasetsdatasets工具在trochvision中
import torchvision from torchvision import transforms as tf from tensorboardX import SummaryWriter train_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=True,download=True) test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True) print(train_dataset[0]) #(, 6) print(train_dataset[1])#同上,返回一张图和标签组成的元组 print(train_dataset.classes) #查看分类类型,此数据集共10类 writer = SummaryWriter('logs\2') #可视化十张图 for i in range(10): img ,label = train_dataset[i] writer.add_image('10train_img',img,i) writer.close()
参数:
CIFAR10:是数据集的名字
root=’./dataset’:保存路径
transform=tf.ToTensor():对图片的转变方法
train=True:训练or测试数据
download=True:是否检测下载
from torch.utils.data import DataLoader from torchvision import transforms as tf import torchvision test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True) #参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并 test_iter = DataLoader(dataset=test_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
DataLoader中参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并
参数:
dataset:读取的数据集
batch_size :批量大小
shuffle :序列的所有元素随机排序
num_worker :进程数
drop_last :是否丢弃尾部不足batch_size的数据
#Dataset类的代码 from torch.utils.data import Dataset from PIL import Image import os # F:python_projectdeep_learningtrainants_image 013035.jpg class MyDate(Dataset): def __init__(self,root_dir,label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir,self.label_dir) self.img_path = os.listdir(self.path) def __getitem__(self, idx): img_name = self.img_path[idx] img_value_path = os.path.join(self.root_dir,self.label_dir,img_name) img = Image.open(img_value_path) label = self.label_dir return img,label def __len__(self): return len(self.img_path) root_dir = 'F:python_projectdeep_learning\train' label_dir = 'ants_image' ant = MyDate(root_dir,label_dir) img,label = ant.__getitem__(1) img.show() print(label)



