栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch的 nn.CrossEntropyLoss()报错

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch的 nn.CrossEntropyLoss()报错

nn.CrossEntropyLoss()
中两个参数,其中的标签必须为long型(int64)的,不能是float32

hwLabels = torch.Tensor(hwLabels).long()
loss_func = nn.CrossEntropyLoss() 
 for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
            output = cnn(b_x)[0]  # cnn output
            loss = loss_func(output, b_y)  # cross entropy loss
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
            if step%5==0:
                loss_count.append(loss.detach().numpy())
                print('{}:t'.format(step),"tloss:",loss.item())
                #torch.save(cnn,r'save/model')
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/649919.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号