使用flask调用训练好的BERT | 原版代码太长了让我随便删点

一、需要环境

Ubuntu18.04 / CUDA10.1 / CUDNN7.6.5 / RTX2080 /ANACONDA3 5.2.0 / TensorFlow-gpu=1.13.1 剩下的看表吧

Package                Version
---------------------- ---------
absl-py                0.11.0
astor                  0.8.1
astunparse             1.6.3
cachetools             4.2.1
certifi                2020.12.5
chardet                4.0.0
click                  7.1.2
coverage               5.5
cycler                 0.10.0
Cython                 0.29.22
Flask                  1.1.2
flatbuffers            1.12
gast                   0.4.0
google-auth            1.27.0
google-auth-oauthlib   0.4.2
google-pasta           0.2.0
grpcio                 1.36.1
h5py                   2.10.0
idna                   2.10
imageio                2.9.0
importlib-metadata     3.7.0
itsdangerous           1.1.0
Jinja2                 2.11.3
Keras-Applications     1.0.8
Keras-Preprocessing    1.1.2
kiwisolver             1.3.1
Markdown               3.3.4
MarkupSafe             1.1.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
mock                   4.0.3
numpy                  1.19.5
oauthlib               3.1.0
opt-einsum             3.3.0
Pillow                 8.1.2
pip                    21.0.1
protobuf               3.15.3
pyasn1                 0.4.8
pyasn1-modules         0.2.8
pyparsing              2.4.7
python-dateutil        2.8.1
requests               2.25.1
requests-oauthlib      1.3.0
rsa                    4.7.2
scipy                  1.5.2
setuptools             39.1.0
six                    1.15.0
tensorboard            1.13.1
tensorboard-plugin-wit 1.8.0
tensorflow             1.13.1
tensorflow-estimator   2.4.0
termcolor              1.1.0
tqdm                   4.59.0
typing-extensions      3.7.4.3
urllib3                1.26.3
Werkzeug               1.0.1
wheel                  0.36.2
wrapt                  1.12.1
zipp                   3.4.0

 其次需要你能够完整跑通Google-BERT官方项目,如果没有自己训练使用预训练权重也未尝不可。

官方项目下载地址:https://github.com/google-research/bert

注意8G显存的话只能用768的权重,而且bs最大22左右。

其他配置都默认。

二、部署

首先需要把推断(Predict)部分的代码独立出来,我是根据CoLA任务做的迁移训练,那么接下来所有东西都根据CoLA来写。

2.1 第一步是把预定义参数提取出来

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import collections
import csv
import time
import math
import modeling
import optimization
import tokenization
import tensorflow as tf
import tqdm
import json
from flask import Flask, render_template, request, jsonify, make_response
bert_config_file = "./model/bert_config.json"
task_name = "cola"
vocab_file = "./model/vocab.txt"
output_dir = "./tmp/cola_output/"
init_checkpoint = "./model/model.ckpt"
do_lower_case = True
max_seq_length = 128
do_predict = True
train_batch_size = 32
eval_batch_size = 8
predict_batch_size = 8
learning_rate = 5e-5
num_train_epochs = 3.0
warmup_proportion = 0.1
save_checkpoints_steps = 1000
iterations_per_loop = 1000
use_tpu = False
tpu_name = None
tpu_zone = None
gcp_project = None
master = None
num_tpu_cores = 8

那种flag的形式本人不太喜欢,看起来字太多了,一眼找不到自己需要的信息。

2.2 一些数据生成模块

class InputExample(object):
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
class PaddingInputExample(object):
    '''没用'''
class InputFeatures(object):
    def __init__(self,input_ids,input_mask,segment_ids,label_id,is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.is_real_example = is_real_example
class DataProcessor(object):
    def get_train_examples(self, data_dir):
        raise NotImplementedError()
    def get_dev_examples(self, data_dir):
        raise NotImplementedError()
    def get_test_examples(self, data_dir):
        raise NotImplementedError()
    def get_labels(self):
        raise NotImplementedError()
    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        with tf.gfile.Open(input_file, "r") as f:
      
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值