栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 系统运维 > 运维 > Linux

图像分割的tf-serving

Linux 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

图像分割的tf-serving

记录一下使用tensorflow-serving部署图像分割的过程

一、将h5权重文件转成saved_model可以部署的模型

changeH5tosavedModel.py

import tensorflow as tf
from nets.unet import Unet as unet



if __name__ == '__main__':
    model = unet((512, 512, 3), 2, 'vgg')
    model.load_weights('EP100-loss0.196-valoss0.284.h5')
    tf.saved_model.save(model, "test/1")

    

二、利用docker开启tensorflow serving服务
docker run -p 8501:8501 --mount type=bind,source=E:projectFilesstandardunetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving

gpu(目前只在linux下测试了,因为win10似乎安装不能nvidia-docker):

首先安装必要的东西:

docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi

然后拉取tensorflow-serving gpu镜像:

docker pull tensorflow/serving:latest-gpu

最后开启模型服务

docker run --gpus all -p 8501:8501 --mount type=bind,source=/home/hbli/pythonFiles/unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving:latest-gpu

MODEL_NAME是自己定的,target最后的unetV1的名字和MODEL_NAME一致,source是被部署的模型所在的文件夹。其他都一样。

三、客户端进行访问

httpClient.py

""" 图像分割的serving """

import cv2
import numpy as np
import requests
import json
import time
from PIL import Image
import colorsys
import matplotlib.pyplot as plt
import os


def resize_image(image, size):
    """ 等比例resize """
    iw, ih  = image.size
    w, h    = size

    scale   = min(w/iw, h/ih)
    nw      = int(iw*scale)
    nh      = int(ih*scale)

    image   = image.resize((nw,nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128,128,128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))

    return new_image, nw, nh


def preprocess_input(image):
    image = image / 127.5 - 1
    return image


input_shape = (512,512) # 与训练的时候一致
num_classes = 2 # 类别+1


def preProcessing(filepath):
    inputs = cv2.imread(filepath)
    old_img = Image.open(filepath)
    h,w = inputs.shape[0],inputs.shape[1]
    # print(f'初始图像size: {h},{w}')

    """ 数据预处理 """
    image_data, nw, nh  = resize_image(old_img, (input_shape[1], input_shape[0]))
    image_data  = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)

    return old_img,(h,w),(nw,nh),image_data


def mainProcess():
    start = time.time()
    ####--------------------------核心代码----------------------------------------####
    """ REST API端口 """
    url = 'http://localhost:8501/v1/models/unetV1:predict'

    data = json.dumps({'inputs':image_data.tolist()}) # 要求输入的数据是json格式

    response = requests.post(url,data=data)
    result = json.loads(response.content)
    outputs = result['outputs'][0]
    output_array = np.array(outputs) # list转numpy数组
    ####--------------------------核心代码---------------------------------------####

    print(f'花费时间:{time.time()-start:.2f}s')
    # print(type(output_array))
    return output_array



def postProcessing():
    """ 对预测结果进行后处理 """
    # resize回图像原始的大小
    pr = cv2.resize(output_array, (w, h), interpolation = cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1) # 取出每一个像素点的种类
    seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))

    if num_classes <= 21:
        colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                        (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                        (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), 
                        (128, 64, 12)]
    else:
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))


    for c in range(num_classes):
        seg_img[:,:,0] += ((pr[:,: ] == c )*(colors[c][0] )).astype('uint8')
        seg_img[:,:,1] += ((pr[:,: ] == c )*(colors[c][1] )).astype('uint8')
        seg_img[:,:,2] += ((pr[:,: ] == c )*(colors[c][2] )).astype('uint8')

    resultImage = Image.fromarray(np.uint8(seg_img))
    image = Image.blend(old_img,resultImage,0.7)

    return image


def saveAndShow(image):
    savename = os.path.basename(filepath)[:-4]+"httpResult.jpg"
    savePath = 'servingOut/'
    if not os.path.exists(savePath):
        os.mkdir(savePath)

    image.save(savePath+savename)

    plt.title(os.path.basename(filepath))
    plt.imshow(image)
    plt.show()


 
if __name__ == '__main__':
    while True:
        try:
            filepath = input('请输入待预测图像路径(输入c退出): ')
            if filepath == 'c':
                break        
            old_img,(h,w),(nw,nh),image_data = preProcessing(filepath=filepath)
            output_array = mainProcess()
            image = postProcessing()
            saveAndShow(image)
        except Exception as e:
            print(e)
            continue
        
        

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/730725.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号