栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch基础知识九【统计属性】

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch基础知识九【统计属性】

统计属性
  • 1. 范数
  • 2. p范数
  • 3. 常用统计属性
  • 4. 高级操作
    • 4.1 where
    • 4.2 gather

1. 范数

2. p范数

3. 常用统计属性

【1】mean、sum、min、max、prod;argmin、argmax

(1) prod()  表示连乘
(2) argmax()、argmin() 不指定维度,会将张量(tensor)打平乘一个一维的tensor,返回索引;
	指定维度后,根据维度返回每一行或每一列中max或min元素的索引。




【2】dim、keepdim

keepdim参数在一列tensor的形式返回所求目标统计值和对应的索引


【3】topk 和 kth

【1】topk
		(1) b的size是p* q* j
		(2) b.topk(n,dim=0) 返回的东西包括两个nqj的tensor,
			第一个tensor是value 第二个tensor是indice 其他思路和max差不多
		(3) b.topk(n,dim=1) 返回的东西包括两个pnj的tensor,
			第一个tensor是value 第二个tensor是indice 其他思路和max差不多
		(4) dim不写 默认为最大的 比如本例中就是2
		(5) largest=False则返回的是最小的n个
	b = torch.rand(2,3,10) #假设2个batch,每个batch3张照片,10表示每张照片依次是0~9的概率
	print(b)
	tensor([[[0.5304, 0.1505, 0.5322, 0.0247, 0.1890, 0.8630, 0.6212, 0.4308,
          0.3447, 0.9590],
         [0.7632, 0.1420, 0.7258, 0.8698, 0.6531, 0.2155, 0.8730, 0.3963,
          0.1711, 0.4374],
         [0.7572, 0.4117, 0.2699, 0.4153, 0.9025, 0.7338, 0.4403, 0.6043,
          0.6109, 0.9631]],

        [[0.5738, 0.9288, 0.2986, 0.8859, 0.2687, 0.9388, 0.4515, 0.9561,
          0.9277, 0.6534],
         [0.4825, 0.9142, 0.4260, 0.4814, 0.8739, 0.9296, 0.8225, 0.5581,
          0.8214, 0.2455],
         [0.5460, 0.6100, 0.5846, 0.7510, 0.5964, 0.4243, 0.8347, 0.7734,
          0.5279, 0.5943]]])

	b.topk(1,dim=2)
	# 0.9590表示第一个batch的第一张照片的10个概率值中,最大的是0.9590 并
	#且我们可以看到它对应的indice是9 表示这张照片是“9”的概率为0.9590
	
	# 0.9296表示第2个batch的第2张照片的10个概率值中,最大的是0.9296 并且
	# 我们可以看到它对应的indice是5 表示这张照片是“5”的概率为0.9590

	torch.return_types.topk(
	values=tensor([[[0.9590],
	         [0.8730],
	         [0.9631]],
	
	        [[0.9561],
	         [0.9296],
	         [0.8347]]]),
	indices=tensor([[[9],
	         [6],
	         [9]],
	
	        [[7],
	         [5],
	         [6]]]))

	b.topk(1)
	torch.return_types.topk(
	values=tensor([[[0.9590],
	         [0.8730],
	         [0.9631]],
	
	        [[0.9561],
	         [0.9296],
	         [0.8347]]]),
	indices=tensor([[[9],
	         [6],
	         [9]],
	
	        [[7],
	         [5],
	         [6]]]))

【2】kthvalue
		kthvalue的分析思路和max带dim的分析思路一样,只是把最大换成第几大
		当kthvalue的dim不写的时候,默认为最大
	print(b.shape,'n')
	print(b)
	torch.Size([2, 3, 10]) 

	tensor([[[0.5304, 0.1505, 0.5322, 0.0247, 0.1890, 0.8630, 0.6212, 0.4308,
	          0.3447, 0.9590],
	         [0.7632, 0.1420, 0.7258, 0.8698, 0.6531, 0.2155, 0.8730, 0.3963,
	          0.1711, 0.4374],
	         [0.7572, 0.4117, 0.2699, 0.4153, 0.9025, 0.7338, 0.4403, 0.6043,
	          0.6109, 0.9631]],
	
	        [[0.5738, 0.9288, 0.2986, 0.8859, 0.2687, 0.9388, 0.4515, 0.9561,
	          0.9277, 0.6534],
	         [0.4825, 0.9142, 0.4260, 0.4814, 0.8739, 0.9296, 0.8225, 0.5581,
	          0.8214, 0.2455],
	         [0.5460, 0.6100, 0.5846, 0.7510, 0.5964, 0.4243, 0.8347, 0.7734,
	          0.5279, 0.5943]]])

	print(b.kthvalue(2,dim=2)[0].shape,'n')
	print(b.kthvalue(2,dim=2))
	
	torch.Size([2, 3]) 

	torch.return_types.kthvalue(
	values=tensor([[0.1505, 0.1711, 0.4117],
	        [0.2986, 0.4260, 0.5279]]),
	indices=tensor([[1, 8, 1],
	        [2, 2, 8]]))



【4】compare

4. 高级操作

where & gather

4.1 where


4.2 gather

官方文档

定义:
	torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

Parameters: 

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

a=t.arange(0,16).view(4,4)
print(a)

index_1=t.LongTensor([[3,2,1,0]])
b=a.gather(0,index_1)
print(b)

index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()
c=a.gather(1,index_2)
print(c)


执行结果:
	tensor([[ 0,  1,  2,  3],
	        [ 4,  5,  6,  7],
	        [ 8,  9, 10, 11],
	        [12, 13, 14, 15]])
        
	tensor([[12,  9,  6,  3]])
	
	tensor([[ 0],
	        [ 5],
	        [10],
	        [15]])


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/656185.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号