PyTorch 中的 unfold 和 fold
# Demo: a 2-D convolution (stride 1, no padding) expressed as unfold -> matmul -> fold.
inp = torch.randn(1, 16, 6, 6)  # [b, in_c, f_h, f_w]
print(inp.shape)  # torch.Size([1, 16, 6, 6])

W = torch.randn(64, 16, 3, 3)  # [out_c, in_c, k_h, k_w]
k_h, k_w = W.shape[-2:]

# unfold extracts every k_h x k_w sliding patch and flattens it into a column:
# result is [b, in_c * k_h * k_w, L], where L is the number of patch positions.
patches = f.unfold(inp, (k_h, k_w))
print(patches.shape)  # torch.Size([1, 144, 16])

# Convolution as a matrix product:
# [b, L, in_c*k_h*k_w] @ [in_c*k_h*k_w, out_c] -> [b, L, out_c] -> [b, out_c, L].
out_unf = patches.transpose(1, 2).matmul(W.view(W.size(0), -1).t()).transpose(1, 2)
print(out_unf.shape)  # torch.Size([1, 64, 16])

# fold is the inverse arrangement of unfold, so it also slides a window of
# kernel_size over output_size. With kernel_size=(1, 1) each of the L columns
# lands on exactly one output pixel, so here fold just reshapes L into
# (out_h, out_w). For stride 1 and no padding: out = in - k + 1.
out_h = inp.size(-2) - k_h + 1
out_w = inp.size(-1) - k_w + 1
out = f.fold(input=out_unf, output_size=(out_h, out_w), kernel_size=(1, 1))
print(out.shape)  # torch.Size([1, 64, 4, 4])

# Sanity check: the unfold/matmul/fold result matches the built-in convolution.
assert torch.allclose(out, f.conv2d(inp, W), atol=1e-4)