栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch积少成多

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch积少成多

torch.nn.Softmax(-1)
X = torch.randn(3,2)
Y = nn.Softmax(dim = -1)(X)
print(X)
print('---')
print(Y)

tensor([[1.6717, 0.1819],
        [1.3746, 1.0038],
        [0.0052, 0.3082]])
---
tensor([[0.8161, 0.1839],
        [0.5917, 0.4083],
        [0.4248, 0.5752]])

可以看到通过dim可以对矩阵的某个维度求softmax,dim=-1表示最后一个维度,即对每一行求。

Pytorch中的torch.cat()函数

C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)
C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)

import torch
>>> A=torch.ones(2,3)    #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)  #4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)  #按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])
Pytorch中的torch.repeat()函数

相当于把某个整体,在某个方向广播。

import torch
x = torch.tensor([1, 2, 3])
print(x.repeat(4, 1))
print("###################################")
print(x.repeat(4, 2, 1))
print("###################################")
print(x.repeat(4, 1, 2))

output:
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
###################################
tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])
###################################
tensor([[[1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3]]])
Pytorch中的torch.bmm()函数

torch.bmm(input, mat2, out=None) → Tensor

torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。
input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/295918.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号