import torch from torch.utils.data import Dataset import PIL from PIL import Image import os二、类的属性
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
其中:
root_dir是根目录,可以是绝对路径,也可以是相对路径。
label_dir是标签
self.path=os.path.join(self.root_dir,self.label_dir)的作用是把根目录和标签拼在一起,所形成的就是一个保存有训练所需数据的文件夹的目录。
def __getitem__(self, index):
img_name=self.img_path[index]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
其中:
__getitem__返回的是img和label,img是图像本身,label是标签。
①定义一个数据集:
root_dir="D:\Pycharm\pythonProject7\hymenoptera_data\train" ants_label_dir="ants" bees_label_dir="bees" ants_dataset=MyData(root_dir,ants_label_dir) bees_dataset=MyData(root_dir,bees_label_dir)
②数据集之间可以相加:
train_dataset=ants_dataset+bees_dataset
假设:ants_dataset有124项,则train_dataset[0]~train_dataset[123]都是ants_dataset的数据,123往后才是bees_dataset的数据。
③img=train_dataset[123]与img,label=train_dataset[123]的区别:
img=train_dataset[123]
执行完此语句,img是:
(
,
‘ants’)
此时的img并不是一张图片,而是图片和标签
img,label=train_dataset[123]
而执行完后者,img是:



