栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch基础知识八【基本数学运算】

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch基础知识八【基本数学运算】

基本数学运算
  • 1. 加减乘除
  • 2. 矩阵乘法
  • 3. 开方
  • 4. 近似运算

1. 加减乘除
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a)
b = torch.tensor([[10,20,30],[40,50,60],[70,80,90]])
# 加法
print(a+b)
print(torch.add(a,b))

# 减法
print(torch.all(torch.eq(a-b,torch.sub(a,b))))

# 乘法
print(torch.all(torch.eq(a*b,torch.mul(a,b))))

# 除法
print(torch.all(torch.eq(a/b,torch.div(a,b))))


执行结果:
	tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
	tensor([[11, 22, 33],
	        [44, 55, 66],
	        [77, 88, 99]])
	tensor([[11, 22, 33],
	        [44, 55, 66],
	        [77, 88, 99]])
	tensor(True)
	tensor(True)
	tensor(True)


2. 矩阵乘法



举例:

3. 开方

【1】普通的乘方、开方运算

aa.rsqrt()  表示先对aa开平方,然后对开平方的结果求倒数
pow(aa,0.5) 表示开平方运算


【2】自然数e的乘方、对数运算

exp(n) 表示:e的n次方
log(a) 表示:ln(a)
log2() 、 log10()

In[18]: a = torch.exp(torch.ones(2,2))
In[19]: a
Out[19]: 
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])

In[20]: torch.log(a)
Out[20]: 
tensor([[1., 1.],
        [1., 1.]])

In[22]: torch.log2(a)
Out[22]: 
tensor([[1.4427, 1.4427],
        [1.4427, 1.4427]])

In[23]: torch.log10(a)
Out[23]: 
tensor([[0.4343, 0.4343],
        [0.4343, 0.4343]])
4. 近似运算

【1】取整、四舍五入、裁剪

floor、ceil 向下取整、向上取整
round 4舍5入
trunc、frac 裁剪


In[24]: a = torch.tensor(3.14)
In[25]: a.floor(),a.ceil(),a.trunc(),a.frac()
Out[25]: (tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))

In[26]: a = torch.tensor(3.499)
In[27]: a.round()
Out[27]: tensor(3.)

In[28]: a = torch.tensor(3.5)
In[29]: a.round()
Out[29]: tensor(4.)

【2】clamp

		torch.clamp(input, min, max, out=None) → Tensor
			将输入input张量每个元素的夹紧到区间 [min,max][min,max],
			并返回结果到一个新张量。

操作定义如下:
		      | min, if x_i < min
		y_i = | x_i, if min <= x_i <= max
      		  | max, if x_i > max
		
	(1) gradient clipping 梯度裁剪
	(2) (min) 小于min的都变为某某值
	(3) (min, max) 不在这个区间的都变为某某值
	(4) 梯度爆炸:一般来说,当梯度达到100左右的时候,就已经很大了,正常在10左右,通过打印梯度的模来查看 w.grad.norm(2)
	(5) 对于w的限制叫做weight clipping,对于weight gradient clipping称为 gradient clipping。



In[30]: grad = torch.rand(2,3)*15

In[31]: grad.max()
Out[31]: tensor(10.6977)

In[32]: grad
Out[32]: 
tensor([[ 6.7738, 10.6977,  4.4314],
        [ 7.8088,  4.8236,  3.6213]])

In[33]: grad.clamp(10)		# 小于10的都变为10
Out[33]: 
tensor([[10.0000, 10.6977, 10.0000],
        [10.0000, 10.0000, 10.0000]])


In[34]: grad.clamp(0,10)   # 不在(0,10)区间的都变为10
Out[34]: 
tensor([[ 6.7738, 10.0000,  4.4314],
        [ 7.8088,  4.8236,  3.6213]])
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/656706.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号