栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

通过FastAPI框架部署Torchvision预训练模型

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

通过FastAPI框架部署Torchvision预训练模型

通过FastAPI框架部署Torchvision预训练模型 介绍

通过FastAPI框架部署Torchvision.Models预训练模型进行图像识别。

PyTorch框架中有一个非常重要的包:torchvision,它由3个子包组成,分别是:

  • torchvision.datasets
  • torchvision.models
  • torchvision.transforms

其中torchvision.models中包含了很多预训练模型,可以直接使用。由于国内的网络环境,可以通过coggle.club手动下载预训练模型镜像。

通常使用Flask框架为预训练模型创建API服务,但如果想做一个满足高并发的机器学习API服务,异步框架FastAPI是一个不错的选择。

相比Flask,FastAPI框架具有以下几大功能:

  • 异步web框架,支持asyncio;
  • 拥有非常高的性能(归功于Starlette和Pydantic);
  • 通过不同的参数声明实现丰富的功能;
  • 自动类型检查,自动生成交互式文档,自动swagger UI;
  • 自带快如闪电的异步服务器Uvicorn;
Coding - api_server.py
  1. 加载所需包
import io
import json
from PIL import Image
from torchvision import models
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
import uvicorn
  1. 初始化
app = FastAPI()
# 加载预训练模型
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
  1. 自定义函数
# 图片文件读取,输出Image.Image格式
def read_imagefile(file) -> Image.Image:
    image = Image.open(io.BytesIO(file))
    return image

# 图片预处理,torchvision.transforms转换Image格式为torch tensor
def transform_image(image_bytes: Image.Image):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    return my_transforms(image_bytes).unsqueeze(0)

# 定义预测函数,图片预处理->模型预测->预测结果转换
def get_prediction(image_bytes: Image.Image):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]
  1. 路由/predict
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    '''
    Parameters
    ----------
    file : UploadFile, optional
        DEscriptION. The default is an image file.

    Returns
    -------
    json : Response with list of dicts.
		Each dict contains class_id, class_name

    '''
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"

    img_bytes = read_imagefile(await file.read())
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return {'class_id': class_id, 'class_name': class_name}
  1. 运行uvicorn.run
if __name__ == "__main__":
    app_str = 'api_server:app'
    uvicorn.run(app_str, host='localhost', port=8000, debug=True, reload=True, workers=1)
Coding - api_test.py

import requests

def test_request(image = 'images/dog.jpg'):
    resp = requests.post("http://localhost:8000/predict",
                         files={"file": open(image,'rb')})
    
    print(resp.json())

if __name__ == '__main__':
	test_request()
swagger UI界面

http://127.0.0.1:8000/docs

安装教程
# Clone the repo
$ git clone https://gitee.com/vencen/cv-fast-api.git
# 创建虚拟环境,安装依赖包

$ conda create -n venv python=3.8

#on windows
$ activate venv
#on linux
$ source activate venv

$ pip install -r requirements.txt
使用说明
  1. 创建.env文件

    # .env file example
    
    
  2. 启动虚拟环境

    # 启动虚拟环境
    #on windows
    $ activate venv
    #on linux
    $ source activate venv
    
  3. 启动api

    # 启动api
    $ uvicorn server:app --reload
    
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/342219.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号