导入头文件:
from torchvision import transforms
transforms.ToTensor()
ToTensor的作用是将导入的图片转换为Tensor的格式,导入的图片为PIL image 或者 numpy.nadrry格式的图片,其shape为(HxWxC)数值范围在[0,255],转换之后shape为(CxHxw),数值范围在[0,1].
transforms.Normalize()
其作用是将图片在每个通道上做标准化处理,即将每个通道上的特征减去均值,再除以方差。
由于 transforms.Normalize()不支持PIL image 格式,所以必须要将图像转换为Tensor格式之后在做标准化处理,做标准化处理的好处在于可以加快网络的收敛,以及提高数据的鲁棒性,防止数据在后续处理中失活,也可以很好的响应激活函数的处理。
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
我们经常看到的这些均值与方差,是从ImageNet数据集上的百万张图片中随机抽样计算得到的。



