栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch之深入理解collate

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch之深入理解collate

作用

collate_fn 即用于collate的function 用于整理数据的函数。
说到整理数据 你当然要会用数据 即会用数据制作工具torch.utils.data.Dataset 虽然我们今天谈的是torch.utils.data.DataLoader 但是 其实

这两个你如何定义;从装载器dataloader中取数据后做什么处理;模型的forward()中如何处理。

这三部分都是有机统一而不可分割的 一个地方改了 其他地方就要改。

emmm… 小小总结 collate_fn笼统的说就是用于整理数据 通常我们不需要使用 其应用的情形是 各个数据长度不一样的情况 比如第一张图片大小是28*28,第二张是50*50 这样的话就如果不自己写collate_fn 而使用默认的 就会报错。

原则

其实说起来 我们也没有什么原则 但是如今大多数做深度学习都是使用GPU 所以这个时候我们需要记住一个总则 只有tensor数据类型才能运行在GPU上 list和numpy都不可以。

从而 我们什么时候将我们的数据转化为tensor是一个问题 我的答案是前一节中的三个部分都可以来转化 只是我们大多数的人都习惯在部分一转化。

基础 dataset

我们必须先看看torch.utils.data.Dataset如何使用 以一个例子为例

import torch.utils.data as Data
class mydataset(Data.Dataset):
 def __init__(self,train_inputs,train_targets):#必须有
 super(mydataset,self).__init__()
 self.inputs train_inputs
 self.targets train_targets
 def __getitem__(self, index):#必须重写
 return self.inputs[index],self.targets[index]
 def __len__(self):#必须重写
 return len(self.targets)
#构造训练数据
datax torch.randn(4,3)#构造4个输入
datay torch.empty(4).random_(2)#构造4个标签
#制作dataset
dataset mydataset(datax,datay)

下面 可以对dataset进行一系列操作 这些操作返回的结果和你之前那个class的三个函数定义都息息相关。我想说 那三个函数非常自由 你想怎么定义就怎么定义 上述只是一种常见的而已 你可以定制一个特色的。

len(dataset)#调用了你上面定义的def __len__()那个函数
dataset[0]#调用了你上面定义的def __getitem__()那个函数
#(tensor([-1.1426, -1.3239, 1.8372]), tensor(0.))

所以我再三强调的是上面的输出结果和你的定义有关 比如你完全可以把def __getitem__()改成

 def __getitem__(self, index):
 return self.inputs[index]#不输出标签

那么

dataset[0]#此时当然变化。
#tensor([-1.1426, -1.3239, 1.8372])

可以看到 是非常随便的 你随便定制就好。

dataloader

torch.utils.data.DataLoader

dataloader Data.DataLoader(dataset,batch_size 2)

4个数据 batch_size 2,所以一共有2个batch。
collate_fn如果你不指定 会调用pytorch内部的 也就是说这个函数是一定会调用的 而且调用这个函数时pytorch会往这个函数里面传入一个参数batch。

def my_collate(batch):
 return xxx

这个batch是什么 这个东西和你定义的dataset, batch_size息息相关。batch是一个列表[x,...,xx],长度就是batch_size 里面每一个元素是dataset的某一个元素 即dataset[i](我在上一节展示过dataset[0] 。

在我们的例子中 由于我们没有对dataloader设置需要打乱数据 即shuffle True 那么第1个batch就是前两个数据 如下

print(datax)
print(datay)
batch [dataset[0],dataset[1]]#所以才说和你dataset中get_item的定义有关。
print(batch)


对 你没有看错 上述代码展示的batch就会传入到pytorch默认的collate_fn中 然后经过默认的处理 输出如下

it iter(dataloader)
nex next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)


其实 上面就是我们常用的 经典的输出结果 即输入和标签是分开的 第一项是输入tensor 第二项是标签tensor 输入的维度变成了(batch_size,input_size)。

但是我们乍一看 将第一个batch变成上述输出结果很容易呀 我们也会 我们下面就来自己写一个collate_fn实现这个功能。

# a simple custom collate function, just to show the idea
# batch is a list of tuple where first element is input tensor and the second element is corresponding label
def my_collate(batch):
 inputs [data[0].tolist() for data in batch]
 target torch.tensor([data[1] for data in batch])
 return [data, target]
dataloader Data.DataLoader(dataset,batch_size 2,collate_fn my_collate)
print(datax)
print(datay)

it iter(dataloader)
nex next(it)
print(nex)


这不就和默认的collate_fn的输出结果一样了嘛 无非就是默认的还把输入变成了tensor,标签变成了tensor,我上面是列表 我改就是了嘛 如下

def my_collate(batch):
 inputs [data[0].tolist() for data in batch]
 inputs torch.tensor(inputs)
 target [data[1].tolist() for data in batch]
 target torch.tensor(target)
 return [inputs, target]
dataloader Data.DataLoader(dataset,batch_size 2,collate_fn my_collate)
it iter(dataloader)
nex next(it)
print(nex)

这下好了吧

对了 作为彩蛋 告诉大家一个秘密 默认的collate_fn函数中有一些语句是转tensor以及tensor合并的操作 所以你的dataset如果没有设计成经典模式的话 使用默认的就容易报错 而我们自己会写collate_fn 当然就不存在这个问题啦。同时 给大家的一个经验就是 一般dataset是不会报错的 而是根据dataset制作dataloader的时候容易报错 因为默认collate_fn把dataset的类型限制得比较死。

应用情形

假设我们还是4个输入 但是维度不固定的。

a [[1,2],[3,4,5],[1],[3,4,9]]
b [1,0,0,1]
dataset mydataset(a,b)
dataloader Data.DataLoader(dataset,batch_size 2)
it iter(dataloader)
nex next(it)

使用默认的collate_fn 直接报错 要求相同维度。

这个时候 我们可以使用自己的collate_fn 避免报错。

不过话说回来 我个人感受是

在这里避免报错好像也没有什么用 因为大多数的神经网络都是定长输入的 而且很多的操作也要求相同维度才能相加或相乘 所以 这里不报错 后面还是报错。如果后面解决这个问题的方法是 在不足维度上进行补0操作 那么我们为什么不在建立dataset之前先补好呢 所以 collate_fn这个东西的应用场景还是有限的。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267385.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号