栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch.nn.functional.grid

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch.nn.functional.grid

前言:以下仅为个人在学习过程中的记录和总结,有问题欢迎友善讨论。

  • 仅以4D input为例(对应二维图片)
torch.nn.functional.grid_sample(input, grid,mode='bilinear', padding_mode='zeros', align_corners=None)
  • Args:

    • input ( N , C , H i n , W i n ) (N,C,H_{in}, W_{in}) (N,C,Hin​,Win​)
      • N:batch size
      • C:feature dimension
      • H_in:input height
      • W_in:input width
    • grid ( N , H o u t , W o u t , 2 ) (N,H_{out}, W_{out}, 2) (N,Hout​,Wout​,2)
      • N: batch size
      • H_out:output height
      • W_out:output width
      • 2: 已标准化的采样坐标
        • 原文【grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1].】
    • align_corners: 是否将input对齐边缘
  • Returns:

    • Output ( N , C , H o u t , W o u t ) (N,C,H_{out}, W_{out}) (N,C,Hout​,Wout​)
  • 关于双线性插值的理论部分可以参考

  • 下面是2个例子帮助理解:

input = torch.tensor([[[[1., 0., 1.],
                        [2., 2., 3.],
                        [3., 4., 5.]]]])
grid_x = torch.tensor([-1, 0, 1, 0.5], dtype=torch.float)
grid_y = torch.tensor([-1, 0, 1, 0.5], dtype=torch.float)
grid_x = grid_x[None, None, :, None] # 扩充维度至[1,1,4,1]
grid_y = grid_y[None, None, :, None]
grid = torch.cat((grid_x, grid_y), dim=-1) # [1,1,4,2]
>>
tensor([[[[-1.0000, -1.0000],
          [ 0.0000,  0.0000],
          [ 1.0000,  1.0000],
          [ 0.5000,  0.5000]]]])
output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=True)
>>
tensor([[[[1.0000, 2.0000, 5.0000, 3.5000]]]])
"""
(x坐标向右,y坐标向下)
align_corners=True: input 9个点本来对应到3x3的格子的中心,现放射状移动到角点,即9个点对应的坐标值为:
				[(-1,-1),(0,-1),( 1,-1),
				 (-1, 0),(0, 0),( 1, 0),
				 (-1, 1),(0, 1),( 1, 1)]
于是:
(-1,-1):1.0000
(0,0):2.0000
(1,1):5.0000
(0.5,0.5)插值计算(刚好在周围4个点的中心):
				[[2 3].
				 [4,5]]
				 (2+3+4+5)/4 = 3.5000
"""
output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=False)
>>
tensor([[[[0.2500, 2.0000, 1.2500, 1.2500, 4.2500]]]])
"""
align_corners=False
与之前不同的是,input 9个点直接对应到3x3的格子的中心,对应坐标值是:
				[(-2/3,-2/3),(0,-2/3),( 2/3,-2/3),
				 (-2/3,   0),(0,   0),( 2/3,   0),
				 (-2/3, 2/3),(0, 2/3),( 2/3, 2/3)]

此时角点的取值与填充方式(padding_mode)有关,默认外围一圈填充0,所以:
					[[0,0,0,0,0],
					 [0,1,0,1,0],
					 [0,2,2,3,0],
					 [0,3,4,5,0],
					 [0,0,0,0,0]]

(-1,-1)角点的值会根据其周围4格进行插值计算,即:(0+0+0+1)/4=0.2500
(0,0)还是在中心点, 2.0000
(1,1):(5+0+0+0)/4 = 1.2500
(0.5,0.5)插值计算,只是现在不再是周围4个点的中心【(1/3,1/3)才是】,代入插值公式计算:
 	2*0.25*0.25 + 3*0.25*0.75 + 4*0.75*0.25 + 5*0.75*0.75 = 4.25
		0.75 = (0.5)/(2/3),即(0.5,0.5)处于3/4比例位置
"""
  • 小结
    • align_corners 影响的是input是否对齐到grid corner上,即(-1,-1)~(1,1)
    • 后续的结果其实都是双线性插值的计算
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/276223.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号