栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch grid

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch grid

grid_sample函数

这篇博客只对bilinear mode进行解释说明,并且会对align_corners为True或False两种情况进行分情况讨论。

torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zero’, align_corners=None)

nn.functional下的grid_sample函数会根据提供的坐标(grid)对input pixels进行采样(sampling),这篇文章只以bilinear interpolation sampling为例。 根据官方文档介绍,input shape必须是4D或5D的,分别用于二维和三维图像的采样(前两个维度为batch size和channel)。

input的shape(4D case)是 ( N , C , H i n , W i n ) (N, C, H_{in}, W_{in}) (N,C,Hin​,Win​), 这个很好理解。

gird的shape(4D case)是 ( N , H o u t , W o u t , 2 ) (N, H_{out}, W_{out}, 2) (N,Hout​,Wout​,2), 这里的H和W是output的长和宽,有一点需要注意的是,grid_sample的output shape是 ( N , C , H o u t , W o u t ) (N, C, H_{out}, W_{out}) (N,C,Hout​,Wout​), 所以output的shape和grid的shape是一样的, 而不是和input的shape一样。grid的最后一个维度2表示的是x,y坐标, 如果是5D的情况,也就是处理三维图像的时候,gird的最后一个维度就是3,因为需要引入z坐标。

grid表示的是的sampling pixel的坐标,这个坐标是被normalized过的,grid坐标取值范围为[-1, 1]。 点(-1,-1)为左上角的pixel,(1,1)为右下的pixel。中间的坐标值为某个浮点数。

grid_sample函数做的就是根据grid坐标,从input的pixels里采样。 如果此坐标下没有对应的input pixel,就要用bilinear interpolation从周围的pixels采样。

下面是Piotr给出的一个例子
https://discuss.pytorch.org/t/solved-torch-grid-sample/51662/2

inp = torch.arange(4*4).view(1, 1, 4, 4).float()
d = torch.linspace(-1, 1, 8)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2).unsqueeze(0)
output = torch.nn.functional.grid_sample(inp, grid, align_corners=False)

meshy是x坐标

meshx是y坐标

align_corners=True

当align_corners=True时,以坐标(-0.7143, -0.7143)为例,请看下图。
因为align_corners=True,所以(-1, -1)点的值为0, (1, 1)点的值为15,可以认为grid的-1和1在是在corner pixel的中心位置。由此可以推出值为1和2的坐标为(-0.3333, 0)和(0.3333, 0)。我们要采样的点(-0.7143, -0.7143)在0, 1, 4, 5中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(-0.7143, -0.7143)的值就okay了。

下图是align_corners=True的output

align_corners=False

当align_corners=False时,以坐标(0.7143, -0.7143)为例,请看下图。
注意:这个例子的坐标和上个例子的坐标不一样。

因为align_corners=False, 所以(-1, -1)点的值不为0,(1, 1)点的值也不是15,grid的-1和1不在corner pixel的中心位置,而是在正方形像素的角。所以(-0.25, -0.25)的值才是0, (0.75, 0.75)的值才是15。由此可以推出值为1和2的坐标分别为(-0.25, -0.75)和(0.25, -0.75)。我们要采样的点(0.7143, -0.7143)在2, 3, 6, 7中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(0.7143, -0.7143)的值就okay了。

下图是align_corners=False的output:

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/657769.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号