栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

有趣的torch.einsum

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

有趣的torch.einsum

import torch
import numpy as np
a = torch.arange(9).reshape(3, 3)
提取矩阵对角线元素
out = torch.einsum('ii->i', a)	# tensor([0, 4, 8])
矩阵转置
out = torch.einsum('ij->ji', a)
out = torch.einsum('...ij->...ji', a) # 高维矩阵最后两维转置
reduce sum
out = torch.einsum('ij->', a)	# tensor(36)
矩阵按列求和
out = torch.einsum('ki->i', a)
矩阵向量乘法
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
out = torch.einsum('ik,k->i', a, b)
out = torch.einsum('ik,k', a, b)	# 箭头右侧符号可以不写,按规则默认推理。
矩阵乘法
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
out = torch.einsum('ik,kj->ij', a, b)
out = torch.einsum('ik,kj', a, b)
向量内积
a = torch.arange(3)
b = torch.arange(3, 6)
out = torch.einsum('i,i->', a, b)
out = torch.einsum('i,i', a, b)
矩阵元素对应相乘并求reduce sum
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
out = torch.einsum('ij,ij->', a, b)
向量外积
a = torch.arange(3)
b = torch.arange(3,7)
out = torch.einsum('i,j->ij', a, b)
batch矩阵乘法
a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
out = torch.einsum('bik,bkj->bij', a, b)
张量收缩

tensor contraction, 用不上,暂时看不懂。

a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
out = torch.einsum('pqrs,tuqvr->pstuv', a, b)
双线性变换

bilinear transformation. Applies a bilinear transformation to the incoming data.

a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
out = torch.einsum('ik,jkl,il->ij', a, b, c)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/326168.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号