torch.tensor(data, dtype=None, device=None, requires_grad=False)
torch.tensor()作为一个函数被调用,返回的是Tensor类型,传入的data可以是list、tuple、scalar、np.array等
究竟是LongTensor、FloatTensor、DoubleTensor根据传入的data来
>>> b = torch.tensor([1,2]) >>> b.dtype torch.int64 >>> c = torch.tensor([1., 2.]) >>> c.dtype torch.float32
Tensor:
这个主要是类,默认是单精度
transpose每次只能指定两个维度交换
torch.transpose(input, dim0, dim1) → Tensor
用transpose可能会带来的问题
>>> a = torch.arange(24).reshape(2,3,4) >>> a.shape torch.Size([2, 3, 4]) >>> b = torch.transpose(a,0,1) >>> b.shape torch.Size([3, 2, 4]) >>> c = b.view(2,3,4) Traceback (most recent call last): File "", line 1, in RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. >>> c tensor([1., 2.]) # 用reshape改变视图或者用.contiguous(),让存储数据连续 >>> c = b.reshape(4,6) >>> b = b.contiguous() >>> d = b.view(4,6) >>> c tensor([[ 0, 1, 2, 3, 12, 13], [14, 15, 4, 5, 6, 7], [16, 17, 18, 19, 8, 9], [10, 11, 20, 21, 22, 23]]) >>> b tensor([[[ 0, 1, 2, 3], [12, 13, 14, 15]], [[ 4, 5, 6, 7], [16, 17, 18, 19]], [[ 8, 9, 10, 11], [20, 21, 22, 23]]]) >>> d tensor([[ 0, 1, 2, 3, 12, 13], [14, 15, 4, 5, 6, 7], [16, 17, 18, 19, 8, 9], [10, 11, 20, 21, 22, 23]]) >>>
permute可以一次将Tensor变换到任意的dimension sequence的排列
tensor.permute(*dims) → Tensor
参考:https://blog.csdn.net/qq_50001789/article/details/120451717
3. ToTensor
ToTensor实际开发中经常用到ToTensor操作,把读进来的图片直接变成tensor,非常好用
from torchvision.transforms import Compose, ToTensor
def transform():
    """Build the evaluation transform pipeline.

    ToTensor converts a PIL image (HWC, uint8) into a CHW FloatTensor
    scaled to [0, 1].
    """
    return Compose([ToTensor()])
class DatasetFromFoldereval(data.Dataset):
    """Paired LR/HR image dataset for super-resolution evaluation.

    Naming convention (inferred from the stem manipulation below): an LR
    file ``<stem>LR.<ext>`` pairs with an HR file ``<stem>HR.<ext>`` in
    ``hr_dir`` — the last two characters of the LR stem are replaced by
    ``'HR'``.

    Args:
        hr_dir (str): directory containing the high-resolution images.
        lr_dir (str): directory containing the low-resolution images.
        upscale_factor (int): factor used to bicubically rescale the LR input.
        transform (callable, optional): applied to input, bicubic and target
            (typically ``ToTensor``).
    """

    def __init__(self, hr_dir, lr_dir, upscale_factor, transform=None):
        super(DatasetFromFoldereval, self).__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        # Sort so index <-> file mapping is deterministic across runs.
        self.HR_img = sorted(os.listdir(self.hr_dir))
        self.LR_img = sorted(os.listdir(self.lr_dir))
        self.hr_filenames = [join(self.hr_dir, x) for x in self.HR_img
                             if is_image_file(x)]
        self.lr_filenames = [join(self.lr_dir, x) for x in self.LR_img
                             if is_image_file(x)]
        self.upscale_factor = upscale_factor
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(input, target, bicubic, filename)`` for one LR image."""
        lr_path = self.lr_filenames[index]
        # NOTE: avoid naming this 'input' — that shadows the builtin.
        lr_img = load_img(lr_path)
        _, file = os.path.split(lr_path)
        # splitext instead of file.split('.'): a filename with more than one
        # dot (e.g. 'img.x4.LR.png') would make 2-value unpacking raise.
        lr_prefix, lr_ext = os.path.splitext(file)
        # Replace the trailing 'LR' marker of the stem with 'HR'.
        hrfile = lr_prefix[:-2] + 'HR' + lr_ext
        bicubic = rescale_img(lr_img, self.upscale_factor)
        # Use the same 'join' helper as __init__ for consistency.
        hr_path = join(self.hr_dir, hrfile)
        target = load_img(hr_path)
        if self.transform:
            lr_img = self.transform(lr_img)
            bicubic = self.transform(bicubic)
            target = self.transform(target)
        return lr_img, target, bicubic, file

    def __len__(self):
        """Dataset length is driven by the LR side of the pairing."""
        return len(self.lr_filenames)
ToTensor实际调用的是functional中的to_tensor
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image (CHW; uint8 inputs are rescaled to [0, 1] floats).

    Raises:
        TypeError: if ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: if an ndarray input is not 2- or 3-dimensional.
    """
    if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Promote grayscale HW to HW1 so the HWC->CHW transpose below works.
            pic = pic[:, :, None]

        # transpose only changes strides, so .contiguous() re-packs the data.
        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility: byte images become floats scaled to [0, 1]
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage stores CHW float32 directly; copy into a pre-sized buffer.
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image: choose a tensor dtype matching the PIL pixel mode
    if pic.mode == 'I':
        # 32-bit signed integer pixels
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        # 16-bit integer pixels
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        # 32-bit float pixels
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        # 1-bit pixels: scale {0, 1} up to {0, 255} so the div(255) below
        # maps them to {0.0, 1.0}
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        # all remaining modes (e.g. 'L', 'RGB'): raw byte buffer
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # PIL size is (width, height), so view as (H, W, C) before permuting.
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
参考:https://www.cnblogs.com/ocean1100/p/9494640.html



