数据目录结构
代码
from torch.utils.data import Dataset
import os
from PIL import Image
# 使用pytorch对数据集进行操作需要继承Dataset类,重写相关方法
class MyData(Dataset):
# 初始化,root_dir为跟路径,data_dir为跟路径下数据集的名字,该名字为该数据集下所有图片的标签
def __init__(self, root_dir, data_dir):
self.root_dir = root_dir
self.data_dir = data_dir
self.label = data_dir
self.data_path = os.path.join(self.root_dir, self.data_dir)
self.imgs = os.listdir(self.data_path)
# 获取数据集中的某一个数据,重写该方法之后即可直接通过 该类实例[下标] 来访问任意元素
# 该函数返回数据集中的某一张图片和相应标签
def __getitem__(self, item):
image_name = self.imgs[item]
image_path = os.path.join(self.root_dir, self.data_dir, image_name)
image = Image.open(image_path)
return image, self.label
def __len__(self):
return len(self.imgs)
ants= MyData("dataset/train","ants")
ant1, label1 = ants[0]
ant1.show()
参考课程地址:https://www.bilibili.com/video/BV1hE411t7RN?p=7



