栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

使用Pytorch自带模型预测图片

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用Pytorch自带模型预测图片

如下是博主使用内置模型vgg16对一张图片(拿的VOC2007数据集里面的图片,按照道理应该拿imagenet的图片,如下mean和std也是用的imagenet数据集上的统计)进行预测的代码

import torch
import torchvision
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import matplotlib.pyplot as plt

vgg16 = torchvision.models.vgg16(pretrained=True).cuda()

#
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(), normalize]
            )
img = Image.open("2008_002682.jpg")
print(img.size)

#对图像进行归一化
img_p = transform(img)
print(img_p.shape)

#增加一个维度
img_normalize = torch.unsqueeze(img_p,0).cuda()
print(img_normalize.shape)

vgg16.eval()

out = vgg16(img_normalize)

#最后一层是1000的一维向量,每一个表示对应类别的概率
print(out.shape)

with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]

_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
prediction = [[classes[idx], percentage[idx].item()] for idx in indices[0][:5]]
print(prediction)

score = []
label = []
for i in prediction:
    print('Prediciton-> {:<25} Accuracy-> ({:.2f}%)'.format(i[0][:], i[1]))
    score.append(i[1])
    label.append(i[0])

print(score)

#把结果show出来,一些用法和matlab很相似
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 8))
fig.sca(ax1)
ax1.imshow(img)
plt.xticks([])
plt.yticks([])

barlist = ax2.bar(range(5), [i for i in score])
barlist[0].set_color('g')
plt.sca(ax2)
plt.ylim([0, 20])

plt.xticks(range(5),
           # [idx2labels[str(i)][1] for i in pred_label_idx],
           [i for i in label],
           rotation='45')
# fig.subplots_adjust(bottom=0.2)
plt.rcParams['font.size'] = '16'  # 设置字体大小
plt.rcParams['axes.unicode_minus'] = False   # 解决坐标轴负数的负号显示问题
plt.show()

 所用的imagenet_classes.txt可从网址进行下载,测试图片样子如下:

 预测结果如下(可知结果是正确的):

 对比博主之前的博文

深度学习平台实现Demo(八) - c#调用python方式完成训练和预测_jiugeshao的专栏-CSDN博客https://blog.csdn.net/jiugeshao/article/details/112093981该博文是keras框架实现了一个预测,对比下来,大概的过程类似,方法差不多。

附:

 pytorch自带了大量内置模型,相关介绍可见如下博客

pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式_shuijinghua的博客-CSDN博客_pth和pt

 Pytorch的内置模型_博客-CSDN博客_pytorch内置模型

pytorch框架--网络方面--pytorch自带模型(增、改)_雪剑封心-CSDN博客

pytorch 如何调用cuda_将Pytorch模型从CPU转换成GPU的实现方法_扎波罗热人的博客-CSDN博客

 Pytorch 高效使用GPU的操作 - 南鹤- - 博客园
pytorch提供的网络模型(预测图片类别)_z1139269312的博客-CSDN博客

 pytorch下一些常用的操作可见如下代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:Icecream.Shao
import torch
import torchvision
from torch import nn

print(torch.cuda.is_available()) #判断是否支持cuda
print(torch.cuda.device_count()) #当前支持cuda的硬件个数
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #选择一个gpu
print(device)
used_gpu_name = torch.cuda.get_device_name(device) #获取所选择的gpu名字
print(used_gpu_name)

#print的常用语法
n=3
print('The squre of',n,'is',n*n)
print('The squre of ' + str(n) + ' is ' + str(n*n))
print('The squre of %s is %s' % (n,n*n))
print('The squre of {1} is {0}'.format(n*n, n))
print(f'model cost:{0.3:.3f}s')

#内置模型的加载方法
#vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
print(vgg16)

#内置数据集的获取方法
train_data = torchvision.datasets.CIFAR10("./data", train=True,transform=torchvision.transforms.ToTensor,download=True)

#增加层以及修改层参数
vgg16.classifier.add_module('my_linear', nn.Linear(1000, 10))
vgg16.classifier[7] = nn.Linear(1000,2)
print(vgg16)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/755492.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号