栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch 的 Dataloader 如何加载 List 对象?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch 的 Dataloader 如何加载 List 对象?

0、写在前面

 在记录该问题解决方案的时候 也有在 csdn 上搜到某位小伙伴遇到同样的问题 但没有说明原因。那我就记录一下吧。

1、问题

之前看到一份代码 在 __init__() 函数中 加载的每一条数据都是一个列表 List【长度为 len_list】 列表中的每一项是一段经过处理的视频 维度为 [C, T, H, W]。

所以 dataset 中每一条数据的维度应该是 [len_list, C, T, H, W]。

按照以往加载数据的经验 我自然而然地认为用 dataloader 返回的数据维度应该是 [B, len_list, C, T, H, W]。然而 事情不是这样的 实际上用 dataloader 返回的数据维度是 [len_list, B, C, T, H, W]。

2、原因

幸亏同实验室的大神了解过这方面的源码 告诉了我原因

如果 dataset 返回的 sample 是序列 Sequence 类【如 字符串(普通字符串和unicode字符串) 列表和元组】的 那 dataloader 默认把 B batch size 那个维度加在序列里每个 item 的 shape 前面。

相关部分源码

elif isinstance(elem, collections.abc.Sequence):
 # check to make sure that the elements in batch have consistent size
 it iter(batch)
 elem_size len(next(it))
 if not all(len(elem) elem_size for elem in it):
 raise RuntimeError( each element in list of batch should be of equal size )
 transposed zip(*batch)
 return [default_collate(samples) for samples in transposed]

源码地址 https://github.com/pytorch/pytorch/blob/ca666982028f32ddf3606c1d6e45a3a83f274d5d/torch/utils/data/_utils/collate.py#L77

3、解决方案

1、在 __init__() 函数里 先把每一条数据转成 Tensor 而不是直接返回 List。这样用 dataloader 加载数据的维度就是我熟悉的 [B, len_list, C, T, H, W]

2、若直接返回 List 则注意在 __getitem__() 函数里处理数据时 维度是 [len_list, B, C, T, H, W]

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267380.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号