栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

amp与gradient checkpoint冲突

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

amp与gradient checkpoint冲突

问题描述:

在PyTorch中同时使用torch.cuda.amp与torch.utils.checkpoint进行训练时,出现了数据类型不一致的问题,具体报错如下所示:

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

scaler = amp.GradScaler()
with amp.autocast():
    outputs = model(inputs) # model中使用了gradient checkpoint
    loss = criterion(outputs,target)
scaler.scale(loss).backward() # ->此处报错
  • torch==1.7.1

原因分析:

#37730@ptrblck

Thanks for raising this issue!
My best guess is that CheckpointFunction.backward uses the stored mixed-precision tensors from its forward, but breaks the autocasting contract for the backward.
If you run scaler.scale(loss).backward() inside the autocast block, it should work for now as a workaround.

在计算反向传播的过程中会使用checkpoint中保存的梯度,然而这个梯度在前向传播过程中是以FP16计算并保存的,但在autocast()上下文之外,模型参数的精度是FP32的,因此出现了类型不一致无法计算的问题。


解决方案:

@ptrblck提出的方案可以解决这个问题,即将反向传播也放在autocast()的上下文之内,使amp模块可以自动将模型参数的精度与checkpoint中保存的梯度进行匹配。
另外#49757直接修改了checkpoint模块的源码,这个pull request已经merge到了PyTorch的主分支中,因此1.8.0以及以上的版本已经修复了这个问题。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/272828.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号