记:用Java重构Python版bert-serving-client

背景

项目需要把Python版的bert-serving-client用Java重新实现(因为Java的运行速度比Python快一些),于是开始了这次尝试。

先上bert-as-service的github地址:https://github.com/hanxiao/bert-as-service

其中client的__init__.py文件地址:https://github.com/hanxiao/bert-as-service/blob/master/client/bert_serving/client/__init__.py

本文主要实现其中的encode、fetch、fetch_all和encode_async四个方法,在Java版中分别命名为encode、fetch、fetchAll和encodeAsync。

导包

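bertClient主要用到ZeroMQ和JSON:前者负责与服务端建立连接,后者负责格式化传输的数据。相关的pom依赖如下(ZeroMQ选用纯Java实现的jeromq,JSON处理同时用到了Gson和org.json)。注意下面同时列出了jeromq的0.5.1正式版和0.5.2-SNAPSHOT快照版两个依赖,实际使用时二选一即可: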


 
 
  1. <dependency>
  2. <groupId>org.zeromq </groupId>
  3. <artifactId>jeromq </artifactId>
  4. <version>0.5.1 </version>
  5. </dependency>
  6. <!-- for the latest SNAPSHOT -->
  7. <dependency>
  8. <groupId>org.zeromq </groupId>
  9. <artifactId>jeromq </artifactId>
  10. <version>0.5.2-SNAPSHOT </version>
  11. </dependency>
  12. <dependency>
  13. <groupId>com.google.code.gson </groupId>
  14. <artifactId>gson </artifactId>
  15. <version>2.8.2 </version>
  16. </dependency>
  17. <dependency>
  18. <groupId>org.json </groupId>
  19. <artifactId>json </artifactId>
  20. <version>20180813 </version> <!--注意:20160810版本不支持JSONArray-->
  21. </dependency>

构造方法

Python支持默认参数而Java不支持,于是我用"成员属性的默认值 + 方法重载"来模拟默认参数。Java版构造时的核心初始化逻辑(init()方法)如下:


 
 
  1. private void init() throws Exception {
  2. mContext = new ZContext();
  3. String url = "tcp://" + mIp + ":";
  4. mIdentity = UUID.randomUUID().toString();
  5. mSendSocket = mContext.createSocket(SocketType.PUSH);
  6. mSendSocket.setLinger( 0);
  7. mSendSocket.connect(url + mPort);
  8. mRecvSocket = mContext.createSocket(SocketType.SUB);
  9. mRecvSocket.setLinger( 0);
  10. mRecvSocket.subscribe(mIdentity.getBytes(CHARSET_NAME));
  11. mRecvSocket.connect(url + mPortOut);
  12. }

对应python版的构造函数:


 
 
  1. def __init__(self, ip='localhost', port=5555, port_out=5556,
  2. output_fmt='ndarray', show_server_config=False,
  3. identity=None, check_version=True, check_length=True,
  4. check_token_info=True, ignore_all_checks=False,
  5. timeout=-1):
  6. self.context = zmq.Context()
  7. self.sender = self.context.socket(zmq.PUSH)
  8. self.sender.setsockopt(zmq.LINGER, 0)
  9. self.identity = identity or str(uuid.uuid4()).encode( 'ascii')
  10. self.sender.connect( 'tcp://%s:%d' % (ip, port))
  11. self.receiver = self.context.socket(zmq.SUB)
  12. self.receiver.setsockopt(zmq.LINGER, 0)
  13. self.receiver.setsockopt(zmq.SUBSCRIBE, self.identity)
  14. self.receiver.connect( 'tcp://%s:%d' % (ip, port_out))
  15. ....
  16. ....

收发数据

收发数据对应python版里的_send()和_recv()函数,两者代码如下


 
 
  1. def _send(self, msg, msg_len=0):
  2. self.request_id += 1
  3. self.sender.send_multipart([self.identity, msg, b'%d' % self.request_id, b'%d' % msg_len])
  4. self.pending_request.add(self.request_id)
  5. return self.request_id
  6. def _recv(self, wait_for_req_id=None):
  7. try:
  8. while True:
  9. # a request has been returned and found in pending_response
  10. if wait_for_req_id in self.pending_response:
  11. response = self.pending_response.pop(wait_for_req_id)
  12. return _Response(wait_for_req_id, response)
  13. # receive a response
  14. response = self.receiver.recv_multipart()
  15. request_id = int(response[ -1])
  16. # if not wait for particular response then simply return
  17. if not wait_for_req_id or (wait_for_req_id == request_id):
  18. self.pending_request.remove(request_id)
  19. return _Response(request_id, response)
  20. elif wait_for_req_id != request_id:
  21. self.pending_response[request_id] = response
  22. # wait for the next response
  23. except Exception as e:
  24. raise e
  25. finally:
  26. if wait_for_req_id in self.pending_request:
  27. self.pending_request.remove(wait_for_req_id)

_send()函数主要调用了发送套接字的send_multipart(),把identity、msg、request_id和msg_len这几部分作为一条多帧消息(multipart message)发送出去。Java里没有与send_multipart()直接对应的方法,可以对前面几帧调用sendMore()、对最后一帧调用send()来代替。

同样,_recv()函数主要调用了接收套接字的recv_multipart(),Java中也没有直接对应的方法。可以先调用recv()读取第一帧,再循环用hasReceiveMore()判断是否还有后续帧,有就继续recv()。最终写出的Java版代码如下:


 
 
  1. public long send(String message) {
  2. return send(message, 0);
  3. }
  4. public long send(String message, int messageLen) {
  5. return send( new String[]{message}, messageLen);
  6. }
  7. public long send(String[] message, int messageLen) {
  8. mRequestId++;
  9. Gson gson = new Gson();
  10. sendMultiPart( new String[]{mIdentity, gson.toJson(message), mRequestId + "", messageLen + ""});
  11. mPendingRequest.add(mRequestId);
  12. return mRequestId;
  13. }
  14. private void sendMultiPart(String[] msgParts) {
  15. try {
  16. int i;
  17. for (i = 0; i < msgParts.length - 1; i++) {
  18. mSendSocket.sendMore(msgParts[i].getBytes(CHARSET_NAME));
  19. }
  20. mSendSocket.send(msgParts[i].getBytes(CHARSET_NAME), 0);
  21. } catch (Exception e) {
  22. e.printStackTrace();
  23. }
  24. }
  25. public Map<String, Object> recv() {
  26. return recv( null);
  27. }
  28. public Map<String, Object> recv(Long waitForReqId) {
  29. try {
  30. while ( true) {
  31. if (waitForReqId != null && mPendingResponse.containsKey(waitForReqId)) {
  32. List< byte[]> response = mPendingResponse.get(waitForReqId);
  33. HashMap<String, Object> resultMap = new HashMap<String, Object>();
  34. resultMap.put(KEY_ID, waitForReqId);
  35. resultMap.put(KEY_CONTENT, response);
  36. return resultMap;
  37. }
  38. List< byte[]> response = recvMutipart( 0);
  39. if (response == null || response.size() == 0) {
  40. return null;
  41. }
  42. long requestId = Utils.byte2Long(response.get(response.size() - 1));
  43. if (waitForReqId == null || waitForReqId == requestId) {
  44. mPendingRequest.remove(requestId);
  45. HashMap<String, Object> resultMap = new HashMap<String, Object>();
  46. resultMap.put(KEY_ID, requestId);
  47. if (response != null) {
  48. resultMap.put(KEY_CONTENT, response);
  49. }
  50. return resultMap;
  51. } else if (waitForReqId != requestId) {
  52. mPendingResponse.put(requestId, response);
  53. }
  54. }
  55. } catch (Exception e) {
  56. e.printStackTrace();
  57. } finally {
  58. if (waitForReqId != null && mPendingRequest.contains(waitForReqId)) {
  59. mPendingRequest.remove(waitForReqId);
  60. }
  61. }
  62. return null;
  63. }
  64. private List< byte[]> recvMutipart( int flag) {
  65. ArrayList< byte[]> result = new ArrayList< byte[]>();
  66. byte[] item = mRecvSocket.recv(flag);
  67. if (item != null) {
  68. result.add(item);
  69. }
  70. while (mRecvSocket.hasReceiveMore()) {
  71. item = mRecvSocket.recv(flag);
  72. if (item != null) {
  73. result.add(item);
  74. }
  75. }
  76. return result;
  77. }

注意:在send()方法中,一定要先用Gson把消息(字符串数组)序列化成JSON格式再发送,否则服务端会报错,客户端也就收不到数据。

在自定义的sendMultiPart()方法中,把字符串编码成字节数组时使用的是UTF-8编码,由自定义常量CHARSET_NAME指定。

接收回复时,由Python版代码(request_id = int(response[-1]))可知,每条回复的最后一帧就是发送时对应的请求id。它是以十进制数字字符串的形式传输的(发送时用的是b'%d' % self.request_id),因此需要把这个byte[]数组转换成long,代码如下:


 
 
  1. public static long byte2Long(byte[] bytes) {
  2. if (bytes == null) {
  3. return - 1L;
  4. }
  5. StringBuilder builder = new StringBuilder();
  6. for ( int i = 0; i < bytes.length; i++) {
  7. builder.append(bytes[i] - 48);
  8. }
  9. return Long.parseLong(builder.toString());
  10. }

比如收到的字节数组是[49, 48],也就是字符'1'和'0'的ASCII码,对应的请求id就是10。byte2Long()把每个字节减去48(字符'0'的ASCII码)得到对应的数字,再把这些数字拼接起来解析成long,就完成了转换。

encode

收发数据完成之后,就可以着力实现encode、fetch、fetchAll和encodeAsync四个方法了。

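encode用来编码文本。调试Python版客户端后发现,传给encode的必须是字符串数组(列表),而不能直接是字符串,所以这里把"szc"转换成了["s", "z", "c"]。需要注意,数组里的每个元素都会被服务端当作一条独立的句子来编码,因此"szc"拆成三个字符后会得到三条向量(下文测试中单次结果为2304个浮点数,即3 × 768)。如果想把"szc"整体当作一句话编码,应发送["szc"]。根据这一点以及Python版代码,可以写出Java版的encode()方法及其重载方法: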


 
 
  1. public List<Object> encode(String text) {
  2. String[] textArray = new String[text.length()];
  3. for ( int i = 0; i < text.length(); i++) {
  4. textArray[i] = "" + text.charAt(i);
  5. }
  6. return encode(textArray, true, false);
  7. }
  8. public List<Object> encode(String text, boolean blocking, boolean showTokens) {
  9. String[] textArray = new String[text.length()];
  10. for ( int i = 0; i < text.length(); i++) {
  11. textArray[i] = "" + text.charAt(i);
  12. }
  13. return encode(textArray, blocking, showTokens);
  14. }
  15. private List<Object> encode(String[] textStr, boolean blocking, boolean showTokens) {
  16. long requestId = send(textStr, textStr.length);
  17. if (!blocking) return null;
  18. Map<String, Object> ndarrayMap = recvNdarray(requestId);
  19. if (ndarrayMap == null) {
  20. return null;
  21. }
  22. List<Float> floatList = (List<Float>) ndarrayMap.get( "embedding");
  23. JSONArray shape = (JSONArray) ndarrayMap.get( "shape");
  24. if (mTokenInfoAvailable && showTokens) {
  25. String token = (String) ndarrayMap.get( "token");
  26. return Arrays.asList(floatList, shape, token);
  27. } else {
  28. return Arrays.asList(floatList, shape, "");
  29. }
  30. }

最下面的私有方法才是真正干活的方法:先把待编码的字符串数组发给服务端,拿到本次请求的requestId,然后判断是否阻塞。非阻塞时直接返回null,不等待本次编码的结果(结果留待之后用fetch取回);阻塞时则等待结果返回,多次编码之间因此是串行执行的。接着调用recvNdarray()方法获取编码结果。Python版里返回的是一个namedtuple,在Java里用Map来对应。下面就来实现这个recvNdarray()。

Python版的_recv_ndarray()方法如下:


 
 
  1. def _recv_ndarray(self, wait_for_req_id=None):
  2. request_id, response = self._recv(wait_for_req_id)
  3. arr_info, arr_val = jsonapi.loads(response[ 1]), response[ 2]
  4. X = np.frombuffer(_buffer(arr_val), dtype=str(arr_info[ 'dtype']))
  5. return Response(request_id, self.formatter(X.reshape(arr_info[ 'shape'])), arr_info.get( 'tokens', ''))

首先调用_recv()获取request_id和response两个值,这也是Java版的recv()返回一个Map(同时存放请求id和回复内容)的原因。

然后,jsonapi.loads()把response[1]这个字节数组解析成JSON对象(即Python里的dict),赋值给arr_info;response[2]是原始字节数据,直接赋值给arr_val。接着根据arr_info中dtype的值,把arr_val解释成对应类型的数组(本文只处理float类型),并按arr_info中的shape调整形状。最后把request_id、编码结果和tokens组成命名元组返回。

明白原理后,可以写出Java版代码如下


 
 
  1. public Map<String, Object> recvNdarray(Long waitForReqId) {
  2. HashMap<String, Object> recvMap = (HashMap<String, Object>) recv(waitForReqId);
  3. if (recvMap == null || !recvMap.containsKey(KEY_CONTENT)) {
  4. return null;
  5. }
  6. long requestId = Long.parseLong(String.valueOf(recvMap.get(KEY_ID)));
  7. List< byte[]> content = (List< byte[]>) recvMap.get(KEY_CONTENT);
  8. JSONObject jsonObject = new JSONObject( new String(content.get( 1)));
  9. String type = jsonObject.getString( "dtype");
  10. if (type.contains( "float")) {
  11. HashMap<String, Object> retMap = new HashMap<String, Object>();
  12. retMap.put(KEY_ID, requestId);
  13. retMap.put( "embedding", Utils.byte2float(content.get( 2)));
  14. retMap.put( "tokens", jsonObject.optString( "tokens", " "));
  15. retMap.put( "shape", jsonObject.get( "shape"));
  16. return retMap;
  17. }
  18. return null;
  19. }

编码结果存放在embedding里,需要把byte数组转换成float列表。服务端返回的字节按小端序(little-endian)排列,每个float占4个字节,据此就可以把byte数组逐个解码成float。


 
 
  1. public static ArrayList<Float> byte2float(byte[] bytes) {
  2. int resultStrLen = bytes.length;
  3. if (resultStrLen % 4 != 0) {
  4. int byteCount = resultStrLen / 4;
  5. int margin = resultStrLen - 4 * byteCount;
  6. if (byteCount > 0) {
  7. bytes = Arrays.copyOfRange(bytes, 0, 4 * byteCount - margin);
  8. }
  9. }
  10. ArrayList<Float> resultArray = new ArrayList<>();
  11. for ( int i = 0; i < bytes.length; i += 4) {
  12. byte[] newBytesFour = Arrays.copyOfRange(bytes, i, i + 4);
  13. resultArray.add(ByteBuffer.wrap(newBytesFour).order(ByteOrder.LITTLE_ENDIAN).getFloat());
  14. }
  15. return resultArray;
  16. }

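思路是先把末尾不足4字节的部分去掉(即只保留前4 * (length / 4)个字节),再每4个字节解码成一个float,就得到了浮点数列表。注意上面代码中的截取长度写成了4 * byteCount - margin,这会多截掉有效数据,而且截取后的长度可能不是4的倍数(最后一组会被补0);正确的截取长度应为4 * byteCount。正常情况下服务端返回的长度本身就是4的倍数,所以测试时不会走到这个分支。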

这样,encode的主要任务就完成了:最后把Map中的embedding(也就是浮点数列表)、shape和token组成列表返回给调用方即可。注意recvNdarray()中存入的键是"tokens",而encode()读取的是"token",两者需要保持一致,否则showTokens为true时取到的将是null。

fetch和fetchAll

然后看一下python里的fetch和fetchAll


 
 
  1. def fetch(self, delay=.0):
  2. time.sleep(delay)
  3. while self.pending_request:
  4. yield self._recv_ndarray()
  5. def fetch_all(self, sort=True, concat=False):
  6. if self.pending_request:
  7. tmp = list(self.fetch())
  8. if sort:
  9. tmp = sorted(tmp, key= lambda v: v.id)
  10. tmp = [v.embedding for v in tmp]
  11. if concat:
  12. ....
  13. return tmp

可见,fetch是一个生成器(用yield实现,类似协程),会为所有已向服务端发送、但尚未取回结果的请求逐个获取结果。fetch_all只是在此基础上按请求id排序(concat为True时还会拼接结果,代码中已省略)。既然目的是实现异步,在Java里就可以用多线程(线程池)来实现,对应代码如下:


 
 
  1. public void fetch(long delay, final IFetchCallback fetchCallback) {
  2. try {
  3. if (delay > 0L) {
  4. Thread.sleep(delay);
  5. }
  6. mExecutorService.submit( new Runnable() {
  7. @Override
  8. public void run() {
  9. ArrayList<Map<String, Object>> fetchResults = new ArrayList<>();
  10. while (mPendingRequest.size() > 0) {
  11. Map<String, Object> recvMap = recvNdarray();
  12. if (recvMap != null) {
  13. fetchResults.add(recvMap);
  14. }
  15. }
  16. if (fetchCallback != null) {
  17. fetchCallback.onFetchResult(fetchResults);
  18. }
  19. }
  20. });
  21. } catch (Exception e) {
  22. e.printStackTrace();
  23. }
  24. }

IFetchCallback是自定义的回调接口:后台线程取回所有待处理请求的结果后,通过它把结果交给调用方处理。

encodeAsync

接下来是异步编码,python版代码如下


 
 
  1. def encode_async(self, batch_generator, max_num_batch=None, delay=0.1, **kwargs):
  2. def run():
  3. cnt = 0
  4. for texts in batch_generator:
  5. self.encode(texts, blocking= False, **kwargs)
  6. cnt += 1
  7. if max_num_batch and cnt == max_num_batch:
  8. break
  9. t = threading.Thread(target=run)
  10. t.start()
  11. return self.fetch(delay)

Python版用生成器+多线程来实现异步编码。batch_generator可以看成一批待编码的文本,子线程遍历它,并以非阻塞方式调用encode。由上面的encode实现可知,非阻塞时只发送数据,并不等待和读取结果。最后调用fetch()统一取回这些请求的编码结果并返回。

Java版同样用线程池来实现:


 
 
  1. public void encodeAsync(final String[] texts, final boolean blocking, final boolean showTokens
  2. , final long delay, final IEncodeResult encodeCallback, final IFetchCallback fetchCallback) {
  3. try {
  4. mExecutorService.submit( new Runnable() {
  5. @Override
  6. public void run() {
  7. List<List<Object>> encodeResults = new ArrayList<>();
  8. for ( int i = 0; i < texts.length; i++) {
  9. List<Object> eachResult = encode(texts[i], blocking, showTokens);
  10. if (eachResult != null) {
  11. encodeResults.add(eachResult);
  12. }
  13. }
  14. if (encodeCallback != null) {
  15. encodeCallback.onEncodeResult(encodeResults);
  16. }
  17. }
  18. });
  19. fetch(delay, fetchCallback);
  20. } catch (Exception e) {
  21. e.printStackTrace();
  22. }
  23. }

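IEncodeResult(参数名为encodeCallback)也是自定义的回调接口,负责输出编码结果。另外需要注意:Python版在子线程里固定以blocking=False调用encode,结果统一交给fetch取回;而Java版会把调用方传入的blocking原样传给encode。如果传入true,编码线程和fetch线程就会同时从同一个接收套接字读数据,而ZeroMQ的套接字不是线程安全的,因此调用encodeAsync时应传入blocking=false。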

测试

先测试能否编码成正确的浮点数列表。

编码相同的字符串“szc”,看一下python版和java版的embedding结果:

python:

java:

可见,两者的数组大小和内容完全一致,说明编码功能已正确实现。

再测试能否正确取回那些尚未读取结果的请求:先把"szc"发送三遍但不接收结果,再调用fetch_all或fetch,检查返回数组的大小是否正确即可。

python版

java版

Java版返回了三个Map,每个Map里是一个大小为2304的结果列表;Python版则直接返回了一个大小为6912的ndarray。3 × 2304 = 6912,大小对得上,说明异步获取结果也实现了。

最后测试一下异步编码:同样把"szc"发送三遍再fetch,看最终的编码结果是否正确。

没问题,至此Java重构bertClient就算完成了。

踩过的坑

java.lang.UnsatisfiedLinkError: org.zeromq.ZMQ$Socket.nativeInit()V

这是因为jzmq依赖本地(native)库,运行时找不到对应的本地实现就会抛出UnsatisfiedLinkError。解决办法是不用jzmq,改用纯Java实现的jeromq,依赖见文章开头。

结语

这几天重构下来,发现Python里很多写法就像"耍赖"一样方便,比如默认参数、无需类型声明、命名元组等(当然,类型随时可能变化,也让人防不胜防)。但随之而来的代价是运行速度的下降,或许这就是"失之东隅,收之桑榆"吧。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值