运行别人的的源代码还报错,所以确定不是代码问题。
问了别人,应该是fft函数对应的torch版本问题,torch1.8.0版本之后的才是torch.fft.fft2
根据网上的总结自己改的
#旧版 新版 torch.rfft torch.fft.fft2 torch.irfft torch.fft.ifft2
还是报错了,函数中使用的参数定义应该也是不一样的
原来不是版本问题,看到pytorch官网上只有这两种函数,所以猜想是不是我写错了
解决办法,把源代码的torch.fft.fft2改成torch.fft.fftn就可以了def D(x, Dh_DFT, Dv_DFT):
x_DFT = torch.fft.fftn(x, dim=(-2,-1)).cuda()
Dh_x = torch.fft.ifftn(Dh_DFT*x_DFT, dim=(-2,-1)).real
Dv_x = torch.fft.ifftn(Dv_DFT*x_DFT, dim=(-2,-1)).real
return Dh_x, Dv_x
然后代码正常运行啦



