Pytorch模型的TF-Serving部署
写在前面
目前PyTorch在学术界几乎完全盖过了Tensorflow,从个人体验来说:
- PyTorch并没有表现出比Tensorflow性能差很多,基本是相当;
- 与Tensorflow的非Eager API相比,PyTorch上手、使用的难度相对来说非常低,并且调试非常容易(可断点可单步);
- 以及笔者还没有体会到的动态图与静态图之争
但是在应用场景下仍然还有将模型使用TF-Serving部署的需求,做转换还是有意义的。
本文将以一个transformer为例来介绍整个流程。
1. PyTorch侧,利用onnx导出
依赖:onnx
model = MultiTaskModel(encoder, sst_head, mrpc_head)
# 关注点1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 关注点2
tensor = torch.tensor((), dtype=torch.int64)
input_ids = tensor.new_ones((1, 48)).to(device)
token_type_ids = tensor.new_ones((1, 48)).to(device)
attention_mask = tensor.new_ones((1, 48)).to(device)
# 关注点3
torch.onnx.export(model,
(input_ids, attention_mask, token_type_ids),