栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

StopIteration: Caught StopIteration in replica 0 on device 0. 问题排查与解决

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

StopIteration: Caught StopIteration in replica 0 on device 0. 问题排查与解决

首先是错误内容截图:(抱歉因为打码有点糊)

我在训练修改后的TransformerXL时,发现了如上的错误,此前代码已经成功地在单GPU下运行过,切换到多卡运行出现该问题。尝试进行解决。

使用的环境是: Pytorch1.11 transformers:4.18

在网上进行查阅后大部分人都说可能是pytorch版本的问题,当前所使用的pytorch版本过高,需要降级到1.4.0版本。

降级听起来比较简单,但是我不想降级到太低的版本,只能走第二条路,修改代码。

首先定位到出错的非源码的最后一行,

param = next(self.parameters())

经过上网查找,发现可能是在训练过程中部分数据的精度不同导致的问题,可能同时存在16位精度和32位精度的数据,尝试在这里进行修改,将其直接指定为torch.float32 进行训练。

原始代码为:

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:
            return None

更改后的代码是: 

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=torch.float32).cuda()
                mems.append(empty)
            return mems
        else:
            return None

成功! 问题解决! 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/980113.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号