栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch入门与实战_pytorch 示例?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch入门与实战_pytorch 示例?

Pytorch基础语法记录

1.dim、size、shape的区别:link1,link2
-dim=0的标量
维度为0的Tensor为标量,标量一般用在Loss这种地方。如下代码定义了一个标量:a = torch.tensor(1.6)。定义标量的方式很简单,只要在tensor函数中传入一个标量初始化的值即可,注意是具体的数据。而torch.Tensor()是Pytorch中的一个类,是默认张量类型torch.FloatTensor()的别名。注意一点,torch.tensor()参数接收的是具体的数据,而torch.Tensor()参数既可以接收数据也可以接收维度分量也就是shape。
-dim=1的张量
dim=1的Tensor一般用在Bais这种地方,或者神经网络线性层的输入Linear Input,例如MINST数据集的一张图片用shape=[784]的Tensor来表示。

dim=1相当于只有一个维度,但是这个维度上可以有多个分量(就像一维数组一样),一维的张量实现方法有很多,下面是三种实现:

def printMsg(k):
	"""输出Tensor的信息,维度,shape"""
	print(k, k.dim(), k.size(), k.shape)

# 1.通过torch.tensor(), 参数传入一个list构造dim=1的Tensor
a = torch.tensor([1.1])
printMsg(a)
# tensor([1.1000]) 1 torch.Size([1]) torch.Size([1])
b = torch.tensor([1.1, 2.2])
printMsg(b)
# tensor([1.1000, 2.2000]) 1 torch.Size([2]) torch.Size([2])

# 2.通过torch.Tensor(), 随机构造dim=1的Tensor
# 这里传入的是shape=1,有1个元素
c = torch.FloatTensor(1)
printMsg(c)
# tensor([1.5056e-38]) 1 torch.Size([1]) torch.Size([1])
# 这里传入的是shape=2,有2个元素
d = torch.FloatTensor(2)
printMsg(d)
# tensor([0., 0.]) 1 torch.Size([2]) torch.Size([2])

# 3.从numpy构造dim=1的Tensor
e = np.ones(2)
print(e)
# array([1., 1.])
e = torch.from_numpy(e)
printMsg(e)
# tensor([1., 1.], dtype=torch.float64) 1 torch.Size([2]) torch.Size([2])

dim=2的张量

dim=2的张量一般用在带有batch的Linear Input,例如MNIST数据集的k张图片如果放再一个Tensor里,那么shape=[k,784]。

# dim=2,shape=[2,3],随机生成Tensor
a = torch.FloatTensor(2, 3)
 
print(a.shape)		#torch.Size([2, 3])
print(a.shape[0])	#2
print(a.shape[1])	#3
print(a.size())	#torch.Size([2, 3])
print(a.size(0))	#2
print(a.size(1))	#3

dim=4的张量
dim=4的张量适合用于CNN表示图像,例如100张MNIST手写数据集的灰度图(通道数为1,如果是RGB图像通道数就是3),每张图高=28像素,宽=28像素,所以这个Tensor的shape=[100,1,28,28],也就是一个batch的数据维度:[batch_size,channel,height,width] 。

如下构建一个shape=[2, 3, 28, 28]的Tensor:a = torch.rand(2, 3, 28, 28)

2.Tensor()和FloatTensor()里面的参数是现成的list数据或者shape,而tensor()【小写的t】参数是现成的数据如list数据。
Tensor()默认等于FloatTensor(),这个默认值也可以修改。

3.rand()函数,如a=torch.rand(3,3),初始化一个3行3列的张量,每个元素值都在[0, 1]区间。要想取[0,10]区间,应使用10*torch.rand(d1,d2),randint只能采样整数项。
rand_like()函数,如torch.rand_like(a),相当于先获取a的shape,再传给torch.rand()
randn(): 初始元素的值取自正态分布。

4.arange(0,10):生成一个包括0但不包括10的等差数列,差值为1。
arange(0,10,2):生成一个包括0但不包括10的等差数列,差值为2。

5.Torch.linspace(0,10,steps=4),输出为tensor([ 0.0000, 3.3333, 6.6667, 10.0000]),也就是把[0,10]切割为相等的4份。

6.torch.eye:对角矩阵 torch.ones:单位矩阵 torch.zeros:全为0的矩阵。

7.切片操作

8.view=reshape,都是进行维度变换。(size不能变,变的是shape)

9.a.unsqueeze(index),在指定下标位置增加一个维度。
如:

>>>a=torch.rand(4,1,28,28)
>>>a.shape
torch.Size([4, 1, 28, 28])
>>>a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])
>>>a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
>>>a.unsqueeze(3).shape
torch.Size([4, 1, 28, 1, 28])

10.expand和repete维度变换(建议使用expand)
expand:

>> x = torch.randn(2, 1, 1, 4)
>> x.expand(-1, 2, 3, -1)		#expand给的参数是新的shape,参数为-1的位置维数不会改变,而且只能改变原来维数为1的位置
torch.Size([2, 2, 3, 4])

repete:

>>>b.shape
torch.Size([1, 32, 1, 1])
>>>b = torch.Tensor(1,32,4,4)
>>>b.shape
torch.Size([1, 32, 4, 4])
>>>b.repeat(4,32,1,1).shape		#repete的参数代表对应位置维数要repete的次数
torch.Size([4, 1024, 4, 4])	#1*4=4,32*32=1024, 1*4=4, 1*4=4

11.transpose和permute维度变换
a.transpose(dim0, dim1, out=None)
函数返回输入矩阵的转置。交换维度dim0和dim1
参数:
dim0 (int) – 转置的第一维,默认0,可选
dim1 (int) – 转置的第二维,默认1,可选
注意只能有两个相关的交换的位置参数。

permute()
参数:
dims (int…*)-换位顺序,必填

transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数。

12.Broadcasting自动扩展
13.合并与分割
-合并
cat

>>>a1 = torch.rand(4,3,32,32)
>>>a2 = torch.rand(5,3,32,32)
>>>torch.cat([a1,a2],dim=0).shape
torch.Size([9, 3, 32, 32])
>>>a2 = torch.rand(4,1,32,32)
>>>torch.cat([a1,a2],dim=0).shape
Traceback (most recent call last):
  File "", line 1, in 
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 1 for tensor number 1 in the list.
>>>torch.cat([a1,a2],dim=1).shape
torch.Size([4, 4, 32, 32])
>>>a1 = torch.rand(4,3,16,32)
>>>a2 = torch.rand(4,3,16,32)
>>>torch.cat([a1,a2],dim=2).shape
torch.Size([4, 3, 32, 32])

stack
stack会增加一个维度,且要求被合并的两个元素维度一致。

-分割
split:by len
chunk: by num

14.数学运算
-矩阵乘法:

@是matmul的重载符,相当于matmul的另一种写法

-开方
power或者**

-clamp
clamb(min)
clamb(min,max)

15.属性统计
-norm
torch.norm的理解
-prod
累乘
-argmin、argmax

-eq和equal

17.高阶操作
-where

-gather
参考链接:pytorch之torch.gather方法

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/783072.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号