栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch报错: Can only calculate the mean of floating types. Got Long instead

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch报错: Can only calculate the mean of floating types. Got Long instead

小问题不要慌!!!!
运行代码:

import sys
sys.path.append('..')
import torch

def simple_batch_norm_1d(x, gamma, beta):
    eps = 1e-5
    x_mean = torch.mean(x, dim=0, keepdim=True)  # dim=0在每一列上求取均值  保留维度进行 broadcast
    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

#  5行3列表示三个特征,每个特征上有五个数据点
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
y = y.float()
print('after bn: ')
print(y)

该代码是学习pytorch数据标准化的代码,对一个tensor求一个均值和方差。
报错如下:

该错误提示也很明显,在求均值的时候数据类型不对,计算得到的是个long型,对其数据类型做个转换即可。
修改如下:

x_mean = torch.mean(x.float(), dim=0, keepdim=True) 

这是运行就没错误啦!!!!

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/503785.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号