栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

关于torch.cumprod()累积连乘

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

关于torch.cumprod()累积连乘

import torch
in_=torch.tensor([[2., 4., 6.], [1., 3., 5.]])
print(in_)
out_prod = torch.cumprod(in_,dim=0)#竖着累积
print("cumulative product:", out_prod)
out_prod = torch.cumprod(in_,dim=1)#横着累积
print("cumulative product:", out_prod)
out1=torch.tensor([[[0., 1.0, 3., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.]],

                    [[0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [0., 0., 0., 0.]]])
print(out1.shape)
print(out1[:, :, 1]==0)#维度为1的数据为0则是true,否则是false
print((out1[:,:,1]==0).float())#将布尔类型转化为true为1,false为0
print((out1[:,:,0]==0).float())
print((out1[:,:,0]==0).float()*(out1[:,:,1]==0).float())
mask = torch.cumprod( (out1[:, :, 0] == 0).float() * (out1[:, :, 1] == 0).float(), dim=0)#先对应相乘然后再按竖着相乘
print(mask)

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/744138.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号