通过FastAPI框架部署Torchvision.Models预训练模型进行图像识别。
PyTorch框架中有一个非常重要的包:torchvision,它由3个子包组成,分别是:
- torchvision.datasets
- torchvision.models
- torchvision.transforms
其中torchvision.models中包含了很多预训练模型,可以直接使用。由于国内的网络环境,可以通过coggle.club手动下载预训练模型镜像。
通常使用Flask框架为预训练模型创建API服务,但如果想做一个满足高并发的机器学习API服务,异步框架FastAPI是一个不错的选择。
相比Flask,FastAPI框架具有以下几大功能:
- 异步web框架,支持asyncio;
- 拥有非常高的性能(归功于Starlette和Pydantic);
- 通过不同的参数声明实现丰富的功能;
- 自动类型检查,自动生成交互式文档,自动swagger UI;
- 自带快如闪电的异步服务器Uvicorn;
- 加载所需包
import io import json from PIL import Image from torchvision import models import torchvision.transforms as transforms from fastapi import FastAPI, File, UploadFile import uvicorn
- 初始化
app = FastAPI()
# 加载预训练模型
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
- 自定义函数
# 图片文件读取,输出Image.Image格式
def read_imagefile(file) -> Image.Image:
image = Image.open(io.BytesIO(file))
return image
# 图片预处理,torchvision.transforms转换Image格式为torch tensor
def transform_image(image_bytes: Image.Image):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
return my_transforms(image_bytes).unsqueeze(0)
# 定义预测函数,图片预处理->模型预测->预测结果转换
def get_prediction(image_bytes: Image.Image):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
- 路由/predict
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
'''
Parameters
----------
file : UploadFile, optional
DEscriptION. The default is an image file.
Returns
-------
json : Response with list of dicts.
Each dict contains class_id, class_name
'''
extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
if not extension:
return "Image must be jpg or png format!"
img_bytes = read_imagefile(await file.read())
class_id, class_name = get_prediction(image_bytes=img_bytes)
return {'class_id': class_id, 'class_name': class_name}
- 运行uvicorn.run
if __name__ == "__main__":
app_str = 'api_server:app'
uvicorn.run(app_str, host='localhost', port=8000, debug=True, reload=True, workers=1)
Coding - api_test.py
import requests
def test_request(image = 'images/dog.jpg'):
resp = requests.post("http://localhost:8000/predict",
files={"file": open(image,'rb')})
print(resp.json())
if __name__ == '__main__':
test_request()
swagger UI界面
http://127.0.0.1:8000/docs
# Clone the repo $ git clone https://gitee.com/vencen/cv-fast-api.git # 创建虚拟环境,安装依赖包 $ conda create -n venv python=3.8 #on windows $ activate venv #on linux $ source activate venv $ pip install -r requirements.txt使用说明
-
创建.env文件
# .env file example
-
启动虚拟环境
# 启动虚拟环境 #on windows $ activate venv #on linux $ source activate venv
-
启动api
# 启动api $ uvicorn server:app --reload



