torch.argmax(input) → LongTensor
Returns the indices of the maximum value of all elements in the input tensor. # 返回输入张量中所有元素的最大值的索引。
参数:input (Tensor) – the input tensor例子:
>>> a = torch.randn(3,3)
>>> a
tensor([[-0.0368, 0.0057, -1.5687],
[-0.2456, 0.0145, -0.4154],
[ 1.0114, -0.4180, -0.5612]])
>>> print(torch.argmax(a))
tensor(6) # 从0开始计数,从左往右,从上往下
>>>
语法一:
torch.argmax(input, dim, keepdim=False) → LongTensor参数:
-
input (Tensor) – the input tensor #输入张量
-
dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned. # 缩小尺寸。如果为None,则返回平坦输入的argmax。
-
keepdim (bool) – whether the output tensors have dim retained or not. Ignored if dim=None.
>>> import torch
>>> b = torch.randn(4,4)
>>> print(b)
tensor([[-1.5364, 1.6827, -0.0245, -0.1265],
[ 0.6040, -0.8682, 0.3914, 0.5424],
[-0.6569, 1.2815, 0.3952, 0.6946],
[-1.1316, 0.7783, 1.2647, -0.4944]])
>>> print(torch.argmax(b,dim =0)) #竖着比较,找最大
tensor([1, 0, 3, 2])
>>> print(torch.argmax(b,dim =1)) #横着比较,找最大
tensor([1, 0, 1, 2])
更复杂一些的例子:
>>> c= torch.randn(2,3,4)
>>> print(c)
tensor([[[ 0.1911, -1.3272, -0.1704, -1.0493],
[ 1.0991, -0.4143, -0.3800, -0.4657],
[-0.3569, -0.6414, 1.3495, -0.0230]],
[[-2.1686, -1.1714, -0.3639, 0.5945],
[-0.4642, 0.8249, -0.0173, 0.1934],
[-0.1629, 1.2108, 1.6179, -0.2537]]])
>>> print(torch.argmax(c,dim=0))
tensor([[0, 1, 0, 1],
[0, 1, 1, 1],
[1, 1, 1, 0]])
>>> print(torch.argmax(c,dim=1))
tensor([[1, 1, 2, 2],
[2, 2, 2, 0]])
>>> print(torch.argmax(c,dim=2))
tensor([[0, 0, 2],
[3, 1, 2]])



