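Preface: what follows is just my personal notes and summary from studying this function. If anything is wrong, friendly discussion is welcome.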
- Only 4-D input is covered here (i.e. 2-D images).
```python
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
```
Args:
- input: $(N, C, H_{in}, W_{in})$
  - N: batch size
  - C: number of channels (feature dimension)
  - H_in: input height
  - W_in: input width
- grid: $(N, H_{out}, W_{out}, 2)$
  - N: batch size
  - H_out: output height
  - W_out: output width
  - 2: the normalized sampling coordinates (x, y) for each output location; x runs along the input width, y along the input height
  - From the docs: "grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]."
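- align_corners: how the normalized range [-1, 1] is laid onto the input. If True, -1 and 1 refer to the centers of the input's corner pixels; if False, they refer to the outer corners (edges) of the corner pixels. The default None behaves like False, and PyTorch emits a warning.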
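The effect of align_corners can be written as an explicit unnormalization formula. The version below is my reading of how a normalized coordinate $x \in [-1, 1]$ is turned into a continuous, 0-indexed column index ($y$ and $H_{in}$ work the same way). Treat it as a sketch that is consistent with the worked examples further down, not as a quote from the documentation.

$$
x_{\text{pix}} =
\begin{cases}
\dfrac{x+1}{2}\,(W_{in}-1), & \texttt{align\_corners=True}\\[6pt]
\dfrac{(x+1)\,W_{in}-1}{2}, & \texttt{align\_corners=False}
\end{cases}
$$

For example, with $W_{in}=3$ the points $x = -1, 0, 0.5, 1$ map to $0, 1, 1.5, 2$ when align_corners=True, and to $-0.5, 1, 1.75, 2.5$ when align_corners=False. Whenever the four interpolation neighbours of a sample fall outside $[0, W_{in}-1]$, their values come from padding_mode.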
Returns:
- output: $(N, C, H_{out}, W_{out})$
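To make the shapes concrete, the sketch below uses arbitrary sizes chosen only for illustration. It builds an "identity" grid with `torch.linspace` and `torch.meshgrid`, which also shows the layout of the grid's last dimension: index 0 is x (width), index 1 is y (height). With align_corners=True the endpoints -1 and 1 land exactly on the corner pixel centers, so sampling should reproduce the input up to floating-point rounding.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; output size equals input size so we can compare.
N, C, H_in, W_in = 2, 3, 4, 5
H_out, W_out = H_in, W_in

inp = torch.arange(N * C * H_in * W_in, dtype=torch.float32).reshape(N, C, H_in, W_in)

# Identity grid: x varies along the width, y along the height.
ys = torch.linspace(-1, 1, H_out)
xs = torch.linspace(-1, 1, W_out)
gy, gx = torch.meshgrid(ys, xs, indexing="ij")   # each (H_out, W_out)
grid = torch.stack((gx, gy), dim=-1)             # (H_out, W_out, 2), x first
grid = grid.unsqueeze(0).repeat(N, 1, 1, 1)      # (N, H_out, W_out, 2)

out = F.grid_sample(inp, grid, mode="bilinear", align_corners=True)
print(out.shape)                               # torch.Size([2, 3, 4, 5])
print(torch.allclose(out, inp, atol=1e-5))     # True
```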
For the theory behind bilinear interpolation, see
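Short of a theory reference, here is a pure-Python sketch of what mode='bilinear' with padding_mode='zeros' computes for one channel of a 2-D image. It is my own reimplementation for illustration, not PyTorch's code; `unnormalize` and `bilinear_sample` are hypothetical helper names. Each sample is a weighted sum of its four neighbouring pixels, with weights given by the fractional offsets, and neighbours outside the image contribute 0.

```python
import math

def unnormalize(coord, size, align_corners):
    """Map a normalized coordinate in [-1, 1] to a continuous pixel index."""
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2

def bilinear_sample(img, x, y, align_corners):
    """Sample a 2-D list `img` (H rows x W cols) at normalized (x, y).

    Out-of-range neighbours read as 0, mimicking padding_mode='zeros'.
    """
    H, W = len(img), len(img[0])
    px = unnormalize(x, W, align_corners)  # column index (x runs along width)
    py = unnormalize(y, H, align_corners)  # row index (y runs along height)
    x0, y0 = math.floor(px), math.floor(py)
    wx, wy = px - x0, py - y0              # weights toward column x0+1 / row y0+1

    def pixel(r, c):
        return img[r][c] if 0 <= r < H and 0 <= c < W else 0.0

    return (pixel(y0, x0) * (1 - wx) * (1 - wy)
            + pixel(y0, x0 + 1) * wx * (1 - wy)
            + pixel(y0 + 1, x0) * (1 - wx) * wy
            + pixel(y0 + 1, x0 + 1) * wx * wy)

img = [[1., 0., 1.],
       [2., 2., 3.],
       [3., 4., 5.]]
points = [(-1, -1), (0, 0), (1, 1), (0.5, 0.5)]
for ac in (True, False):
    print(ac, [bilinear_sample(img, x, y, ac) for x, y in points])
# True [1.0, 2.0, 5.0, 3.5]
# False [0.25, 2.0, 1.25, 4.25]
```

The two printed rows match the `grid_sample` outputs in the examples that follow.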
Two examples to build intuition:
```python
import torch

input = torch.tensor([[[[1., 0., 1.],
                        [2., 2., 3.],
                        [3., 4., 5.]]]])  # shape [1, 1, 3, 3]
grid_x = torch.tensor([-1, 0, 1, 0.5], dtype=torch.float)
grid_y = torch.tensor([-1, 0, 1, 0.5], dtype=torch.float)
grid_x = grid_x[None, None, :, None]  # expand to [1, 1, 4, 1]
grid_y = grid_y[None, None, :, None]
grid = torch.cat((grid_x, grid_y), dim=-1)  # [1, 1, 4, 2]
print(grid)
# tensor([[[[-1.0000, -1.0000],
#           [ 0.0000,  0.0000],
#           [ 1.0000,  1.0000],
#           [ 0.5000,  0.5000]]]])
```
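In the grid above, x and y are equal for every point, so swapping the two would go unnoticed. Because the last grid dimension is (x, y) rather than (row, col), a quick check is worthwhile when building grids by hand. Assuming x indexes columns and y indexes rows, sampling (x=1, y=-1) should return the top-right value and (x=-1, y=1) the bottom-left value of the same input. The names `img` and `probe` are illustrative only.

```python
import torch
import torch.nn.functional as F

img = torch.tensor([[[[1., 0., 1.],
                      [2., 2., 3.],
                      [3., 4., 5.]]]])

# Two probe points given as (x, y): top-right corner, then bottom-left corner.
probe = torch.tensor([[[[1., -1.],
                        [-1., 1.]]]])  # shape [1, 1, 2, 2] = (N, H_out, W_out, 2)
out = F.grid_sample(img, probe, mode="bilinear", align_corners=True)
print(out)  # tensor([[[[1., 3.]]]])
```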
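```python
output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=True)
print(output)
# tensor([[[[1.0000, 2.0000, 5.0000, 3.5000]]]])
```

(x points right, y points down.) With align_corners=True, the 9 input pixels, which would otherwise sit at the centers of a 3x3 grid of cells, are spread outward so that the corner pixels land exactly on the corners. Their coordinates, in row-major order, are: (-1,-1), (0,-1), (1,-1), (-1,0), (0,0), (1,0), (-1,1), (0,1), (1,1). Therefore:
- (-1,-1): 1.0000
- (0,0): 2.0000
- (1,1): 5.0000
- (0.5,0.5): interpolated; it sits exactly at the center of its four neighbours [[2, 3], [4, 5]], so the value is (2+3+4+5)/4 = 3.5000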
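```python
output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=False)
print(output)
# tensor([[[[0.2500, 2.0000, 1.2500, 4.2500]]]])
```

With align_corners=False, unlike the previous case, the 9 input pixels sit directly at the centers of the 3x3 cells, so their coordinates, in row-major order, are: (-2/3,-2/3), (0,-2/3), (2/3,-2/3), (-2/3,0), (0,0), (2/3,0), (-2/3,2/3), (0,2/3), (2/3,2/3). The values at the corner points (±1) therefore depend on padding_mode. By default the input is effectively surrounded by a ring of zeros:

$$
\begin{bmatrix}
0&0&0&0&0\\
0&1&0&1&0\\
0&2&2&3&0\\
0&3&4&5&0\\
0&0&0&0&0
\end{bmatrix}
$$

- (-1,-1): the corner sits at the center of its four surrounding cells, so its value is (0+0+0+1)/4 = 0.2500
- (0,0): still the center pixel, 2.0000
- (1,1): (5+0+0+0)/4 = 1.2500
- (0.5,0.5): no longer the midpoint of its four neighbours (that would be (1/3,1/3)). Since 0.5/(2/3) = 0.75, the point lies 3/4 of the way from the center pixel toward its neighbour along each axis, and the bilinear formula gives $2\cdot0.25\cdot0.25 + 3\cdot0.25\cdot0.75 + 4\cdot0.75\cdot0.25 + 5\cdot0.75\cdot0.75 = 4.25$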
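The dependence of the corner values on padding_mode can be checked by re-running the same grid with padding_mode='border'. As I understand it, 'border' clamps out-of-range sample coordinates to the edge pixels instead of reading zeros. If that holds, only the two corner samples should change: (-1,-1) should read pixel (0,0) and (1,1) should read pixel (2,2), while the interior samples stay the same.

```python
import torch
import torch.nn.functional as F

img = torch.tensor([[[[1., 0., 1.],
                      [2., 2., 3.],
                      [3., 4., 5.]]]])
grid = torch.tensor([[[[-1., -1.], [0., 0.], [1., 1.], [0.5, 0.5]]]])  # [1, 1, 4, 2]

results = {}
for pad in ("zeros", "border"):
    out = F.grid_sample(img, grid, mode="bilinear",
                        padding_mode=pad, align_corners=False)
    results[pad] = out.flatten().tolist()
    print(pad, results[pad])
# zeros [0.25, 2.0, 1.25, 4.25]
# border [1.0, 2.0, 5.0, 4.25]
```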
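- Summary
  - align_corners decides what sits on the corners of the grid range (-1,-1) and (1,1): the centers of the input's corner pixels (True), or the outer edges of those pixels (False).
  - Everything after that is plain bilinear interpolation. padding_mode only decides the values of neighbours that fall outside the input.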
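The first summary point can be made concrete by listing where each pixel center lands in normalized coordinates. `pixel_centers` below is a hypothetical helper that inverts the unnormalization used in the examples, and it assumes size >= 2 when align_corners=True. For a 3-pixel axis it reproduces the coordinate lists worked out above.

```python
from fractions import Fraction

def pixel_centers(size, align_corners):
    """Normalized coordinate of each pixel center along an axis of length `size`."""
    if align_corners:
        # Corner pixel centers sit exactly on -1 and 1 (requires size >= 2).
        return [Fraction(2 * i, size - 1) - 1 for i in range(size)]
    # Pixel edges sit on -1 and 1, so centers are inset by half a pixel.
    return [Fraction(2 * i + 1, size) - 1 for i in range(size)]

for ac in (True, False):
    print(ac, [str(c) for c in pixel_centers(3, ac)])
# True ['-1', '0', '1']
# False ['-2/3', '0', '2/3']
```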



