栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

bert-as-service的优化浅析

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

bert-as-service的优化浅析

一、服务部署

使用zeroMQ的Python实现PyZMQ,它提供了轻量级和快速的消息传递实现。C/S消息传递的简单使用示例如下:

import zmq
import zmq.decorators as zmqd

@zmqd.socket(zmq.PUSH)
def send(sock):
    sock.bind('tcp://*:5555')
    sock.send(b'hello')
 
# in another process   
@zmqd.socket(zmq.PULL)
def recv(sock):
    sock.connect('tcp://localhost:5555')
    print(sock.recv())  # shows b'hello'
二、服务加速

服务加速整体架构图如下:

  • freezed: 将动态图转化为静态图,将变量转换为常量,即tf.Variable --> tf.Constant
  • Pruned: 删除图中所有用不到的节点和边
  • Quantized: 将tf.float32转换成tf.float16或者tf.unit8
    tensorflow提供了freezing和pruning的api,只需要定义好输入和输出的节点即可,比如:
input_tensors = [input_ids, input_mask, input_type_ids]
output_tensors = [pooled]
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.graph_util import convert_variables_to_constants

# get graph
tmp_g = tf.get_default_graph().as_graph_def()

sess = tf.Session()
# load parameters then freeze
sess.run(tf.global_variables_initializer())
tmp_g = convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])

# pruning
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(tmp_g, [n.name[:-2] for n in input_tensors],
    [n.name[:-2] for n in output_tensors],
    [dtype.as_datatype_enum for dtype in dtypes], False)
    
with tf.gfile.GFile('optimized.graph', 'wb') as f:
    f.write(tmp_g.SerializeToString())
三、降低服务延迟

保证一次初始化

def input_fn_builder(sock):
    def gen():
        while True:
            # receive request
            client_id, raw_msg = sock.recv_multipart()
            msg = jsonapi.loads(raw_msg)
            tmp_f = convert_lst_to_features(msg)
            yield {'client_id': client_id,
                   'input_ids': [f.input_ids for f in tmp_f],
                   'input_mask': [f.input_mask for f in tmp_f],
                   'input_type_ids': [f.input_type_ids for f in tmp_f]}

    def input_fn():
        return (tf.data.Dataset.from_generator(gen,
            output_types={'input_ids': tf.int32, 'input_mask': tf.int32, 'input_type_ids': tf.int32, 'client_id': tf.string},
            output_shapes={'client_id': (), 'input_ids': (None, max_seq_len), 'input_mask': (None, max_seq_len),'input_type_ids': (None, max_seq_len)})
                .prefetch(10))
    return input_fn
# initialize BERT model once
estimator = Estimator(model_fn=bert_model_fn)
# keep listen and predict
for result in estimator.predict(input_fn_builder(client), yield_single_examples=False):
    send_back(result)

如果有GPU那么prefetch(10)的操作可以获得10%的加速

四、提高服务的扩展性

假设多个客户端同时向服务器发送请求。 在保证并行化计算之前必须考虑的是服务器应该如何处理接收? 如果它收到第一个请求,保持这个连接直到它发回结果; 然后继续第二个请求? 如果有 100 个客户会发生什么? 服务器是否应该使用相同的逻辑来管理 100 个连接?
考虑另外一个场景,假设有一个客户端每 10 毫秒发送 10K 个句子。 服务器将工作并行化为子任务,并将它们分配给多个 GPU 工作人员。 然后另一个客户端加入,每秒发送一个句子。 这个小批量客户端理论上应该立即得到结果。 不幸的是,由于所有 GPU 工作人员都忙于为第一个客户端进行计算和接收,因此在服务器完成来自第一个客户端的 100 个批次(每个批次有 10K 个句子)之前,第二个客户端将永远不会获得时间段。

当多个客户端连接到一台服务器时,就会出现可扩展性和负载平衡问题。 在 bert-as-service 中,实现了一个push/pull和publish/subscribe sockets的ventilator-worker-sink pipeline。 ventilator的作用类似于批处理调度器和负载平衡器。 它将来自客户端的大请求划分为mini作业。 在将这些mini作业发送给worker之前平衡了它们的负载。 worker从ventilator接收mini作业并进行实际的 BERT 推理,最后将结果发送到接收器(sink)。 接收器(sink)收集所有worker的mini作业的输出。 它检查来自的ventilator所有请求的完整性,并将完整的结果发布给客户端。整体结构如下图所示:

原文链接Serving Google BERT in Production using Tensorflow and ZeroMQ

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339815.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号