栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

TorchVision官方文档翻译为中文-示例库Tensor转换与JIT-002

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TorchVision官方文档翻译为中文-示例库Tensor转换与JIT-002

此示例演示了张量图像上的图像变换现在支持的各种功能。特别是,我们展示了如何在GPU上执行图像转换,以及如何使用JIT编译编写它们的脚本。

在v0.8.0之前,torchvision中的转换传统上是以PIL为中心的,因此存在多个限制。现在,从v0.8.0开始,转换实现与Tensor和PIL兼容,我们可以实现以下新功能:

变换多波段torch张量图像(具有3-4个以上通道)
torchscript与用于部署的模型一起进行转换
支持GPU加速
批处理转换,如视频
直接以torchscript支持的torch张量读取和解码数据(用于PNG和JPEG图像格式)
(注意:这些特征仅适用于张量图像)

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

read_image()函数的作用是:读取图像并将其作为张量直接加载

dog1 = read_image(str(Path('assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('assets') / 'dog2.jpg'))
show([dog1, dog2])

在GPU上转换图像
大多数变换支持PIL图像顶部的张量(要可视化变换的效果,请参阅此链接)。使用张量图像,如果cuda可用,我们可以在GPU上运行变换!

import torch.nn as nn

transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dog1 = dog1.to(device)
dog2 = dog2.to(device)

transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])


可编写脚本的转换,便于通过torchscript进行部署
现在,我们将展示如何结合图像转换和模型正向传递,同时使用torch.jit.script获得单个脚本化模块。
让我们定义一个预测模块,该模块转换输入张量,然后对其应用ImageNet模型。

from torchvision.models import resnet18


class Predictor(nn.Module):

    def __init__(self):
        super().__init__()
        self.resnet18 = resnet18(pretrained=True, progress=False).eval()
        self.transforms = nn.Sequential(
            T.Resize([256, ]),  #由于torchscript类型的限制,我们在列表中使用单个int值
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)

现在,让我们定义预测器的脚本和非脚本实例,并将其应用于相同大小的多个张量图像。

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)

输出

Downloading: “https://download.pytorch.org/models/resnet18-f37072fd.pth” to /home/matti/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
/home/matti/miniconda3/envs/pytorch-test/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448216815/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

我们可以验证脚本模型和非脚本模型的预测是相同的:

import json

with open(Path('assets') / 'imagenet_class_index.json', 'r') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")

输出

Prediction for Dog 1: ['n02113023', 'Pembroke']
Prediction for Dog 2: ['n02106662', 'German_shepherd']

由于模型是脚本化的,因此可以很容易地转储到磁盘上并重新使用

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()

更多计算机视觉与图形学相关资料,请关注微信公众号:计算机视觉与图形学实战

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/294135.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号