栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【pytorch笔记系列】数据加载

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【pytorch笔记系列】数据加载

数据加载

Pytorch学习(三)定义自己的数据集及加载训练
Pytorch之Dataset与DataLoader–陈亮的博客

【总结】

Dataset类是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示

class MyDataSet(data.Dataset):#需要继承torch.utils.data.Dataset
	def __init__(self, opt, is_train) #初始化文件路径或文件名列表。
	def getitem(self, index): # 编写支持数据集索引的函数
         #1. 从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)
         #2. 预处理数据(例如torchvision.Transform)
         #3. 返回数据对(例如图像和标签)
         #这里需要注意的是,第一步:read one data,是一个data
	def __len__(self):# 数据集大小

getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。而SSD代码中txt文件存储文件名,通过以每一行的形式对应获取xml文件和jpg文件的路径,以元组的形式存储

class VOCDetection(data.Dataset):
    def __init__(self, opt, image_sets, is_train):
        获取txt文件地址
        按行读取文件名
        self.ids.append((img_path, ano_path))
    def __getitem__(self, index):
        img_path, ano_path = self.ids[index]
        ......
        return image, target
    def get_annotations(self, path):
        # 树的形式读取xml文件内容
        return np.array(boxes), np.array(labels)
    def __len__(self):
        return len(self.ids)

由于Dataset数据集的数据量大,可以直接使用utils.data.DataLoader对其进行一系列操作。
其中shuffle用以打乱数据集内数据分布的顺序,让数据随机化,这样可以避免过拟合。

train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)

其实,Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/768017.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号