主要参考:Link
1. 导入torch并查看其版本import torch print(torch.__version__)2. 随机种子
def set_up(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def train():
set_up(2021) //这里的2021可为任意整数
// ...
3. 查看张量的基本信息
调试代码时最多的就是查看张量的形状和维度信息
tensor = torch.zeros(2, 3, 4) print(tensor.size()) # 查看张量的形状 print(tensor.dim()) # 查看张量的维度4. torch与numpy转换
一般将ndarray转换为tensor比较多,因为将tensor转换到ndarray之后,运算会在cpu上运行,会大大降低运行速度。
tensor = torch.zeros(2, 3, 4) np = tensor.cpu().numpy() tensor = torch.from_numpy(np).float()5. numpy数组转换为图像
常用于可视化,或加载图片数据
iamge = PIL.Image.fromarray(ndarray.astype(np.unit8)) //numpy数组转Image图像 ndarray = np.asarray(PIL.Image.open(path)) //Image图像转numpy数组6. 张量拼接
cat:在给定维度上对输入的张量序列seq进行连接操作,所有的tensors必须为相同的shape或者为空。
x = torch.randn(2, 3) torch.cat((x, x, x), 0) torch.cat((x, x, x), 1)
stack:沿着新的维度拼接一个序列的tensors
torch.stack(tensors, dim=0, *, out=None) → Tensor7.矩阵乘法
torch = torch.mm(mat1, mat2)8. 模型定义
class ConvNet(nn.MOdule):
def __init__(ConvNet, num_classes=10):
super(ConvNet, self).__init__()
...
def forward(self, x):
...
model = ConvNet(number_classes).to(device) // 调用该类时即刻自动调用forward



