使用zeroMQ的Python实现PyZMQ,它提供了轻量级和快速的消息传递实现。C/S消息传递的简单使用示例如下:
import zmq
import zmq.decorators as zmqd
@zmqd.socket(zmq.PUSH)
def send(sock):
sock.bind('tcp://*:5555')
sock.send(b'hello')
# in another process
@zmqd.socket(zmq.PULL)
def recv(sock):
sock.connect('tcp://localhost:5555')
print(sock.recv()) # shows b'hello'
二、服务加速
服务加速整体架构图如下:
- freezed: 将动态图转化为静态图,将变量转换为常量,即tf.Variable --> tf.Constant
- Pruned: 删除图中所有用不到的节点和边
- Quantized: 将tf.float32转换成tf.float16或者tf.unit8
tensorflow提供了freezing和pruning的api,只需要定义好输入和输出的节点即可,比如:
input_tensors = [input_ids, input_mask, input_type_ids] output_tensors = [pooled]
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.graph_util import convert_variables_to_constants
# get graph
tmp_g = tf.get_default_graph().as_graph_def()
sess = tf.Session()
# load parameters then freeze
sess.run(tf.global_variables_initializer())
tmp_g = convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])
# pruning
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(tmp_g, [n.name[:-2] for n in input_tensors],
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes], False)
with tf.gfile.GFile('optimized.graph', 'wb') as f:
f.write(tmp_g.SerializeToString())
三、降低服务延迟
保证一次初始化
def input_fn_builder(sock):
def gen():
while True:
# receive request
client_id, raw_msg = sock.recv_multipart()
msg = jsonapi.loads(raw_msg)
tmp_f = convert_lst_to_features(msg)
yield {'client_id': client_id,
'input_ids': [f.input_ids for f in tmp_f],
'input_mask': [f.input_mask for f in tmp_f],
'input_type_ids': [f.input_type_ids for f in tmp_f]}
def input_fn():
return (tf.data.Dataset.from_generator(gen,
output_types={'input_ids': tf.int32, 'input_mask': tf.int32, 'input_type_ids': tf.int32, 'client_id': tf.string},
output_shapes={'client_id': (), 'input_ids': (None, max_seq_len), 'input_mask': (None, max_seq_len),'input_type_ids': (None, max_seq_len)})
.prefetch(10))
return input_fn
# initialize BERT model once
estimator = Estimator(model_fn=bert_model_fn)
# keep listen and predict
for result in estimator.predict(input_fn_builder(client), yield_single_examples=False):
send_back(result)
如果有GPU那么prefetch(10)的操作可以获得10%的加速
四、提高服务的扩展性假设多个客户端同时向服务器发送请求。 在保证并行化计算之前必须考虑的是服务器应该如何处理接收? 如果它收到第一个请求,保持这个连接直到它发回结果; 然后继续第二个请求? 如果有 100 个客户会发生什么? 服务器是否应该使用相同的逻辑来管理 100 个连接?
考虑另外一个场景,假设有一个客户端每 10 毫秒发送 10K 个句子。 服务器将工作并行化为子任务,并将它们分配给多个 GPU 工作人员。 然后另一个客户端加入,每秒发送一个句子。 这个小批量客户端理论上应该立即得到结果。 不幸的是,由于所有 GPU 工作人员都忙于为第一个客户端进行计算和接收,因此在服务器完成来自第一个客户端的 100 个批次(每个批次有 10K 个句子)之前,第二个客户端将永远不会获得时间段。
当多个客户端连接到一台服务器时,就会出现可扩展性和负载平衡问题。 在 bert-as-service 中,实现了一个push/pull和publish/subscribe sockets的ventilator-worker-sink pipeline。 ventilator的作用类似于批处理调度器和负载平衡器。 它将来自客户端的大请求划分为mini作业。 在将这些mini作业发送给worker之前平衡了它们的负载。 worker从ventilator接收mini作业并进行实际的 BERT 推理,最后将结果发送到接收器(sink)。 接收器(sink)收集所有worker的mini作业的输出。 它检查来自的ventilator所有请求的完整性,并将完整的结果发布给客户端。整体结构如下图所示:
原文链接Serving Google BERT in Production using Tensorflow and ZeroMQ



