在PyTorch中同时使用torch.cuda.amp与torch.utils.checkpoint进行训练时,出现了数据类型不一致的问题,具体报错如下所示:
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
scaler = amp.GradScaler()
with amp.autocast():
outputs = model(inputs) # model中使用了gradient checkpoint
loss = criterion(outputs,target)
scaler.scale(loss).backward() # ->此处报错
- torch==1.7.1
原因分析:
#37730@ptrblck
Thanks for raising this issue!
My best guess is that CheckpointFunction.backward uses the stored mixed-precision tensors from its forward, but breaks the autocasting contract for the backward.
If you run scaler.scale(loss).backward() inside the autocast block, it should work for now as a workaround.
在计算反向传播的过程中会使用checkpoint中保存的梯度,然而这个梯度在前向传播过程中是以FP16计算并保存的,但在autocast()上下文之外,模型参数的精度是FP32的,因此出现了类型不一致无法计算的问题。
解决方案:
@ptrblck提出的方案可以解决这个问题,即将反向传播也放在autocast()的上下文之内,使amp模块可以自动将模型参数的精度与checkpoint中保存的梯度进行匹配。
另外#49757直接修改了checkpoint模块的源码,这个pull request已经merge到了PyTorch的主分支中,因此1.8.0以及以上的版本已经修复了这个问题。



