栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

《Pytorch学习指南》- Dataset和Dataloader用法详解

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

《Pytorch学习指南》- Dataset和Dataloader用法详解

目录
      • 前言
      • DataSet
      • DataLoader
      • 数据构建
        • 1. 创建Dataset 类 :sparkles:
        • 2. 读取数据 :ambulance:
        • 3. 返回数据 :zap:
        • 读取数据 :art:
      • 注意细节 :rocket:
      • 对比实验

前言

本章节主要介绍如何使用torch.utils.data 中的Dataset和Dataloader来构建数据集, 重点要看使用细节

DataSet
  • torch.utils.data.Dataset
    • 功能 : Dataset抽象;类, 所有自定义的Dataset都需要继承他, 并重写相应的方法
    • getitem(self, index)
      1. 接收一个索引, 返回一个样本 : index => label, data
      2. 返回的样本的大小要一样
DataLoader
  • torch.utils.data.DataLoader
    • 功能 : 创建可以迭代的数据装载器
    • 参数 :
      1. dataset : Dataset类对象, 决定数据从哪读取以及如何读取
      2. batchsize: 决定数据批次大小
      3. num_works: 多进程读取数据的线程数
      4. shuffle: 每个 epoch 是否乱序
      5. 当样本数不能被batchsize整除时, 是否舍去最后一个batch的数据
    • 名词解释 :
      1. 样本总数 : 80, batchsize : 8 => 1 Epoch = 10 iteration
数据构建 1. 创建Dataset 类 ✨
class WeiBoDataset(Dataset):
	pass
2. 读取数据 

注意 : 我们一般会在初始化的时候就加载进数据, 读取数据函数需要自定义

class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)
3. 返回数据 ⚡️
  • 这里需要注意的是, len 是必须要设置的, 返回的是你数据集的大小
  • 根据返回的len来构建索引, 然后把构建好的索引传入__getitem__里
  • getitem 根据传进来的索引获取对应的数据, 可以在这个方法里对数据进行处理
class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)

    def __len__(self):
        """
            这个必须要设置, getitem中的index就是根据这个来设置的
        :return:
        """
        return len(self.data)

    def __getitem__(self, index):
        label = 1
        # features = [str(i) for i in range(10)]
        features = np.array([i for i in range(10)])
        return label, features
读取数据 
weibo_dataset=WeiBoDataset("../../datasets/weibo_test_data.csv)
dataloader=DataLoader(weibo_dataset,batch_size=1024,shuffle=True)
for i, batch in enumerate(dataloader):
	# batch : [label, features] 组成
    print(type(batch[0]), type(batch[1]))
注意细节 
  1. 先获取数据集的大小 len
  2. 根据len生成index, 然后shuffle
  3. 根据shuffle后的数据以及batch_size生成索引列表batch_index, 索引列表的大小为 batch_size
  4. 获取每个batch的数据时, 根据batch_index传入到 getitem 获取对应的数据
  5. 注意 : batch的数据类型取决于__getitem__返回的类型, 一般都会转换为tensor
  6. 有的数据类型是无法转换为tensor的, 比如 元素类型为str的list
  7. default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists found
  8. 上面报错原因就是 因为数据无法转换为 tensor , 而类型又不属于 tensors, numpy arrays, numbers, dicts or lists 这几种
  9. 如果返回的数据是集合类型, 可以直接使用 np.array() 转换为ndarray类型, 这样会被自动转换为tensor, 当然要求这个集合类型的元素类型是tensor有的
  10. 如果是tensor没有的,比如 str 类型的, 反而会报错, 比如 7. 报错
对比实验

注意 features的元素类型是str, 那么可以看到下面的输出结果中 label 是 tensor, features 是 list类型的

def __getitem__(self, index):
	label = 1
	# 转换为 ndarray 会报错
	# features = np.array([str(i) for i in range(10)]) 
    features = [str(i) for i in range(10)]
    return label, features
 
 
 
 

下面将feature中的数据元素换成了int类型的, 并且对将list转换为ndarray, 这样在获取batch时数据会自动转换为tensor , 但是这里需要注意的是, 上面的数据是不能用np.array()的, 这是因为 batch 必须包含 tensors, numpy arrays, numbers, dicts or lists 这几种类型, 其他的都会报错, 具体可以查看

def __getitem__(self, index):
	label = 1
    features = np.array([i for i in range(10)])
    return label, features
 
 
 
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/349848.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号