在run_classifier.py文件中添加函数serving_input_fn
def serving_input_fn():
    """Build the ServingInputReceiver used to export the model in SavedModel format.

    Uses raw feature tensors as inputs. (If build_parsing_serving_input_receiver_fn
    were used instead, the exported signature would expect serialized tf.Example
    protos rather than plain tensors.)

    :return: a ServingInputReceiver wiring the four BERT input tensors.
    """
    # NOTE(review): TF1-style tf.placeholder — under TF2 this needs tf.compat.v1.
    # The classification task has 2 classes; label_ids is a required input of
    # the serving signature even though it is unused at inference time.
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    # build_raw_serving_input_receiver_fn returns a callable; invoke it here so
    # that serving_input_fn itself yields the receiver export_savedmodel expects.
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn
在main函数中修改验证代码
if FLAGS.do_eval:
    # The TPU export path is not wanted here; disable it before exporting,
    # otherwise export_savedmodel tries (and fails) to build a TPU graph.
    estimator._export_to_tpu = False
    # Write the SavedModel for TF-Serving into the configured directory.
    estimator.export_savedmodel(FLAGS.trans_model_dir, serving_input_fn)
2. docker部署bert模型
- 查看centos的版本
cat /etc/redhat-release
- 安装docker
yum install docker
- 启动docker
service docker start
- 查看docker安装的镜像
docker images
- 安装tensorflow-serving镜像
docker pull tensorflow/serving
- 在docker上启动镜像生成容器
docker run -t --rm -p 5000:8501 -v /opt/fuzzy_model/output:/models/bert-model -e MODEL_NAME=bert-model tensorflow/serving &
目录结构如下(原文此处的目录结构示意图在转录中丢失,待补充);下面是调用该服务的客户端脚本:
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 6 14:41:13 2020
@author:
"""
import json
import os
import requests
from fuzzy_query import tokenization
import numpy as np
def get_work_dir():
    """Return the directory four levels above this file.

    Equivalent to four nested os.path.dirname calls on __file__;
    presumably this resolves to the project root — verify against deployment layout.
    """
    path = __file__
    # Each dirname strips one trailing path component.
    for _ in range(4):
        path = os.path.dirname(path)
    return path
# Path to the BERT vocabulary file; resolved relative to the working directory.
vocal_file_path = 'vocab.txt'
# vocal_file_path = os.path.join(get_work_dir(), 'var', 'vocab.txt')
# WordPiece tokenizer; do_lower_case=True — presumably matches an uncased
# BERT checkpoint, verify against the trained model.
tokenizer = tokenization.FullTokenizer(vocab_file=vocal_file_path,
do_lower_case=True)
# Fixed sequence length every input is truncated/padded to; must match the
# max_seq_length the model was exported with — TODO confirm.
max_seq_length = 64
def text2ids(text_list):
    """Convert raw texts into fixed-length BERT input features.

    Each text is WordPiece-tokenized, truncated to max_seq_length - 2
    (reserving room for the [CLS] and [SEP] markers), mapped to vocabulary
    ids, and zero-padded to max_seq_length.

    :param text_list: list of input strings.
    :return: tuple (input_ids_list, input_mask_list, label_ids_list,
        segment_ids_list); each element is a list with one entry per text.
    """
    input_ids_list = []
    input_mask_list = []
    label_ids_list = []
    segment_ids_list = []
    for text in text_list:
        # Label is unknown at inference time; 0 is a placeholder the
        # serving signature still requires.
        label = 0
        tokens_a = tokenizer.tokenize(text)
        # Reserve two positions for [CLS] and [SEP].
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        # Single-sentence input: every token belongs to segment 0.
        segment_ids = [0] * len(tokens)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Mask is 1 for real tokens, 0 for padding positions.
        input_mask = [1] * len(input_ids)
        # Zero-pad all three feature vectors up to the fixed length.
        pad_len = max_seq_length - len(input_ids)
        input_ids += [0] * pad_len
        input_mask += [0] * pad_len
        segment_ids += [0] * pad_len
        input_ids_list.append(input_ids)
        input_mask_list.append(input_mask)
        label_ids_list.append(label)
        segment_ids_list.append(segment_ids)
    return input_ids_list, input_mask_list, label_ids_list, segment_ids_list
def bert_class(textList):
    """Classify texts by calling the TF-Serving REST prediction endpoint.

    :param textList: list of input strings.
    :return: list of predicted class indices (argmax over the model outputs),
        one per input text.
    """
    input_ids_list, input_mask_list, label_ids_list, segment_ids_list = \
        text2ids(textList)
    data = json.dumps(
        {
            "name": 'bert',
            "signature_name": 'serving_default',
            "inputs": {
                'input_ids': input_ids_list,
                'input_mask': input_mask_list,
                'label_ids': label_ids_list,
                'segment_ids': segment_ids_list
            }})
    headers = {"content-type": "application/json"}
    url = 'http://10.30.239.205:5000/v1/models/bert-model:predict'
    # Timeout keeps the client from hanging forever if the service is down.
    json_response = requests.post(url, data=data, headers=headers, timeout=60)
    predictions = json.loads(json_response.text)['outputs']
    pre_list = [np.argmax(result) for result in predictions]
    print(pre_list)
    return pre_list
if __name__ == "__main__":
name_list = ['饰品']
bert_class(name_list)



