栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch常用mask命令

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch常用mask命令

文章目录
  • 前言
  • 1.Tensor.masked_fill_(mask, value)
    • 举个例子
  • 2.torch.masked_select(input, mask, *, out=None) → Tensor
    • 举个例子
  • 3.Tensor.masked_scatter_(mask, source)
    • 举个例子


前言


mask是深度学习里面常用的操作,最近在研究transformer的pytorch代码,总能看到各种mask的命令,在这里总结一下

1.Tensor.masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.

Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
举个例子
import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3,3)
a.masked_fill(mask, 0)
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)#可以对mask取反
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
2.torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

(注意)The returned tensor does not use the same storage as the original tensor

Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with
举个例子
import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])
3.Tensor.masked_scatter_(mask, source)

Tensor.masked_scatter_(mask, source)
Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask

source大小和mask至少一样,能够被广播到Tensor上,或者source和Tensor一样
作用就是把source里mask是true的位置挑出来给Tensor

Parameters
mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from
举个例子
import torch
mask = torch.BoolTensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]])
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(2,3,3)
s = torch.ones_like(a)
a.masked_scatter(mask, s)
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],

#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/503916.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号