栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【Pytorch】Pytorch基础

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【Pytorch】Pytorch基础

张量的结构操作 一、创建张量

张量创建的方法和Numpy中创建array的方法十分相似。

1.1 从Python列表或者元组创建张量
a = torch.tensor([1,2,3], dtype=torch.float)
a = torch.tensor((1,2,3), dtype=torch.float)
1.2 使用arange生成张量
b = torch.arange(start=1, end=10, step=1)
1.3 使用linspace/logspace生成张量
c = torch.linspace(start=0, end=10, steps=10, requires_grad=True)
# 注意torch.linspace/logspace中的steps参数和torch.arange中的step参数的区别
c = torch.logspace(start=0, end=10, steps=10, base=10, requires_grad=False)
1.4 使用ones/zeros创建张量
d = torch.zeros((3,3))
d = torch.ones((2,3))

需要注意的是torch.zeros_like或torch.ones_like,二者可以快速生成给定tensor一样shape的0或1向量。

e = torch.zeros_like(d, dtype=torch.int)
e = torch.ones_like(d, dtype=torch.float)
1.5 创建随机张量
# torch.randint --> Returns a tensor filled with random integers generated uniformly
g = torch.randint(low=0, high=10, size=[2,2])
# 0-1均匀分布
f = torch.rand([5])
# 均匀随机分布
f = torch.randn([5])
# 正态随机分布
# mean (Tensor): the tensor of per-element means
# std (Tensor): the tensor of per-element standard deviations
f = torch.normal(mean=torch.zeros(3,3),std=torch.ones(3,3))
# 整数随机排列
# torch.randperm --> Returns a random permutation of integers from ``0`` to ``n - 1``.
f = torch.randperm(20)
1.6 创建特殊矩阵
# 单位矩阵
g = torch.eye(2,2)
# 对角矩阵
# 注意torch.diag的输入必须是一个tensor
g = torch.diag(torch.tensor([1,2,3]))
二、索引切片

张量的索引和切片与Numpy亦十分类似,切片时支持缺省函数和省略号,也可以通过索引和切片对部分元素进行修改。

# 使用省略号可以表示多个冒号
In[0]: print(a)
Out[0]: tensor([[[ 0,  1,  2],
		         [ 3,  4,  5],
		         [ 6,  7,  8]],
		
		        [[ 9, 10, 11],
		         [12, 13, 14],
		         [15, 16, 17]],
		
		        [[18, 19, 20],
		         [21, 22, 23],
		         [24, 25, 26]]])
Out[1]: print(a[...,1])
Out[1]: tensor([[ 1,  4,  7],
		        [10, 13, 16],
		        [19, 22, 25]])

对于不规则的切片提取,可以采用如torch.index_select、torch.take、torch.gather、torch.masked_select等方法。上述这些方法可以完成提取张量的部分元素值,但不能更改张量的部分元素值得到新的张量。如果需要修改张量的部分元素得到新的张量,可以使用torch.where、torch.index_fill、torch.masked_fill;其中torch.index_fill和torch.masked_fill选取元素逻辑分别与torch.index_select和torch.masked_select相同。

2.1 torch.index_select

Pytorch: torch.index_select
该函数有三个参数:

    input:即被索引的张量dim:即索引的维度index:index参数属性为IntTensor或者LongTenosr,index是一个一维保存期望索引目标的序列( the 1-D tensor containing the indices to index)
2.2 torch.take

Pytorch: torch.take
t o r c h . t a k e torch.take torch.take函数首先将输入的Tensor展开为一维张量,输出一个与 i n d e x index index参数相同shape的张量;该函数有两个参数:

    input:输入张量index:该参数属性为LongTensor,存储我们期望索引数据的索引下标
2.3 torch.gather

Pytorch: torch.gather

2.4 torch.masked_select

Pytorch: torch.masked_select
该函数返回一个一维的张量,这个张量由输入的张量map一个为布尔张量的mask选择得到。

    input (Tensor) – the input tensor.mask (BoolTensor) – the tensor containing the binary mask to index with
2.5 torch.where

Pytorch: torch.where
参数:

    condition: 如果condition为True,返回x,否则返回yx: 从condition这个boolean张量为True的index返回x对应位置的元素。y: 元素选择逻辑与x相同
三、维度变换

Pytorch中用于维度变换的函数主要有torch.reshape、torch.squeeze、torch.unsqueeze、torch.transpose

3.1 torch.squeeze

Pytorch: torch.squeeze
如果张量在某个维度上只有一个元素,使用这个函数可以消除这个维度,如将 t o r c h . S i z e ( [ 1 , 2 ] ) torch.Size([1,2]) torch.Size([1,2])形状的张量变为 t o r c h . S i z e ( [ 2 ] ) torch.Size([2]) torch.Size([2])
torch.unsqueeze的作用与该函数作用效果相反。

3.2 torch.transpose

Pytorch: torch.transpose
该函数用于交换张量的维度,常用于图片存储格式的变换上。如果张量是一个二维的矩阵,通常会使用 m a t r i x . t ( ) matrix.t() matrix.t(),这个操作等价于 t o r c h . t r a n s p o s e ( m a t r i x , 0 , 1 ) torch.transpose(matrix, 0, 1) torch.transpose(matrix,0,1)
参数为:

    input:输入张量dim0:第一个需要被转置的维度dim1:第二个需要被转置的维度
四、合并分割

Pytorch中提供了torch.stack、torch.cat来将多个张量合并,torch.split将一个张量分割为多个张量。注意torch.stack会增加维度,而torch.cat只是连接。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/768135.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号