NLP: How to Load Data in Batches
Background: suppose you have fine-tuned (or feature-extracted with) a pretrained model such as BERT, ALBERT, RoBERTa, or the Tencent word vectors on some NLP downstream task, and exported the result as a pb model. How do you speed up prediction and cut latency? Batch prediction, of course! For convenient reuse, I wrapped Su Jianlin's (苏神's) code into a simple class. The code below is short, but it helps to be familiar with the following material:
Resources:
- Source code: adapted from the sentiment-analysis example in the examples folder of Su Jianlin's bert4keras
- If subclassing a parent class and initializing it is unfamiliar, see: "Some issues when a Python subclass initializes its parent class"
- For more on applying BERT in practice, see Zhang Junlin's Zhihu article: "Innovation in the BERT era (applications): BERT's progress across NLP"
- Key technique: yield. Two introductions: "A brief analysis of Python yield" and "The usage of yield in Python: the simplest, clearest explanation"
- Key technique: __iter__. See "Understanding Python __iter__ in depth" and "Python 3 generators explained" (worth a careful read if you interview on generators and iterators), or "Why must Python iterators implement __iter__?"
- From the 机器学习算法工程师 WeChat account: "Mastering iterators from scratch and building a minimal DataLoader"
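As a quick refresher on the two key techniques above (`yield` and `__iter__`), here is a minimal, framework-free batcher in the same spirit as the class below. `MiniBatcher` is invented purely for illustration:

```python
class MiniBatcher:
    """Minimal illustration: __iter__ as a generator that yields fixed-size batches."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for item in self.data:
            batch.append(item)
            if len(batch) == self.batch_size:
                yield batch   # hand one full batch to the caller, then resume here
                batch = []
        if batch:             # flush the final, possibly smaller batch
            yield batch


batches = list(MiniBatcher(range(7), batch_size=3))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Because `__iter__` is itself a generator function, an instance can be looped over directly with `for`, which is exactly the pattern bert4keras's `DataGenerator` relies on.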
Effect:
- Compared with predicting one text at a time, each batch predicts roughly k * batch_size times faster (k ≈ 1–2 in my runs)
Reason:
- For N samples, batch prediction needs only about N/batch_size matrix multiplications (or dot products), whereas single-text prediction pays two costs per sample: T(preprocessing) + T(matrix multiplication)
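The reasoning above can be sketched with plain NumPy (toy shapes, not a real BERT layer): one batched matrix product replaces batch_size separate vector-matrix products and pays the per-call overhead only once, while producing identical results:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, hidden = 32, 768
X = rng.standard_normal((batch_size, hidden))   # a batch of input vectors
W = rng.standard_normal((hidden, hidden))       # one weight matrix

# Single-sample path: batch_size separate vector-matrix products
single = np.stack([x @ W for x in X])

# Batched path: one matrix-matrix product over the whole batch
batched = X @ W

assert np.allclose(single, batched)
```

The same idea carries over to a full transformer forward pass: the batched path amortizes per-call overhead (Python loops, session dispatch, preprocessing) across the whole batch.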
Code example (copy, paste, and run):

```python
import os

# Install dependencies via the Tsinghua PyPI mirror if they are missing
try:
    os.system("pip install -i https://pypi.tuna.tsinghua.edu.cn/simple loguru")
    os.system("pip install -i https://pypi.tuna.tsinghua.edu.cn/simple bert4keras")
except Exception:
    pass

import numpy as np
from loguru import logger
from bert4keras.backend import set_gelu
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open

set_gelu('tanh')  # switch the GeLU variant


def load_data(filename):
    """Load sentence-pair data: each line is tab-separated text1, text2, label."""
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            text1, text2, label = l.strip().split('\t')
            D.append((text1, text2, int(label)))
    return D


# Data generator class
class data_generator(DataGenerator):
    def __init__(self, maxlen, dict_path, data, batch_size, buffer_size):
        # Initialize the parent class, then add our own attributes
        super(data_generator, self).__init__(data=data, batch_size=batch_size, buffer_size=buffer_size)
        self.maxlen = maxlen
        self.tokenizer = Tokenizer(dict_path, do_lower_case=True)

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text1, text2, label) in self.sample(random):
            token_ids, segment_ids = self.tokenizer.encode(text1, text2, maxlen=self.maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []


if __name__ == "__main__":
    # Name the instance `generator` so it does not shadow the class name
    generator = data_generator(maxlen=128,
                               dict_path="./bert/chinese_L-12_H-768_A-12/vocab.txt",
                               data=load_data("./test.txt"),
                               batch_size=32,
                               buffer_size=None)
    for x_true, y_true in generator:
        # Print every sample in the batch
        for idx in range(len(x_true[0])):
            logger.info("Sample {0}:".format(idx))
            logger.info("word2id: {0}".format(str(list(x_true[0][idx]))))
            logger.info("segment_ids: {0}".format(str(list(x_true[1][idx]))))
            logger.info("label: {0}\n\n".format(str(list(y_true[idx]))))
```
Output:
2020-10-14 22:28:57.416 | INFO | __main__:<module>:73 - Sample 0:
2020-10-14 22:28:57.417 | INFO | __main__:<module>:74 - word2id: [101, 6443, 3300, 4312, 676, 6821, 2476, 7770, 3926, 4638, 102, 6821, 2476, 7770, 3926, 1745, 8024, 6443, 3300, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO | __main__:<module>:76 - label: [0]
2020-10-14 22:28:57.417 | INFO | __main__:<module>:73 - Sample 1:
2020-10-14 22:28:57.417 | INFO | __main__:<module>:74 - word2id: [101, 5739, 7413, 5468, 4673, 784, 720, 5739, 7413, 3297, 1962, 102, 5739, 7413, 5468, 4673, 3297, 1962, 5739, 7413, 3221, 784, 720, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO | __main__:<module>:76 - label: [1]
2020-10-14 22:28:57.418 | INFO | __main__:<module>:73 - Sample 2:
2020-10-14 22:28:57.418 | INFO | __main__:<module>:74 - word2id: [101, 6821, 3221, 784, 720, 2692, 2590, 8024, 6158, 6701, 5381, 1408, 102, 2769, 738, 3221, 7004, 749, 8024, 6821, 3221, 784, 720, 2692, 2590, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.418 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.418 | INFO | __main__:<module>:76 - label: [0]
2020-10-14 22:28:57.418 | INFO | __main__:<module>:73 - Sample 3:
2020-10-14 22:28:57.419 | INFO | __main__:<module>:74 - word2id: [101, 4385, 1762, 3300, 784, 720, 1220, 4514, 4275, 1962, 4692, 1450, 8043, 102, 4385, 1762, 3300, 784, 720, 1962, 4692, 4638, 1220, 4514, 4275, 1408, 8043, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:76 - label: [1]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:73 - Sample 4:
2020-10-14 22:28:57.419 | INFO | __main__:<module>:74 - word2id: [101, 6435, 7309, 3253, 6809, 4510, 2094, 1322, 4385, 1762, 4638, 2339, 6598, 2521, 6878, 2582, 720, 3416, 6206, 3724, 3300, 1525, 763, 102, 676, 3215, 4510, 2094, 1322, 2339, 6598, 2521, 6878, 2582, 720, 3416, 1557, 102]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:76 - label: [0]
2020-10-14 22:28:57.419 | INFO | __main__:<module>:73 - Sample 5:
2020-10-14 22:28:57.420 | INFO | __main__:<module>:74 - word2id: [101, 3152, 4995, 4696, 4638, 4263, 2001, 5013, 1408, 102, 2001, 5013, 4696, 4638, 6158, 3152, 4995, 2397, 749, 1408, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:76 - label: [0]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:73 - Sample 6:
2020-10-14 22:28:57.420 | INFO | __main__:<module>:74 - word2id: [101, 6843, 5632, 2346, 976, 4638, 7318, 6057, 784, 720, 4495, 3189, 4851, 4289, 1962, 102, 6843, 7318, 6057, 784, 720, 4495, 3189, 4851, 4289, 1962, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:76 - label: [1]
2020-10-14 22:28:57.420 | INFO | __main__:<module>:73 - Sample 7:
2020-10-14 22:28:57.420 | INFO | __main__:<module>:74 - word2id: [101, 6818, 3309, 677, 3216, 4638, 4510, 2512, 102, 6818, 3309, 677, 3216, 4638, 4510, 2512, 3300, 1525, 763, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO | __main__:<module>:76 - label: [1]
2020-10-14 22:28:57.421 | INFO | __main__:<module>:73 - Sample 8:
2020-10-14 22:28:57.421 | INFO | __main__:<module>:74 - word2id: [101, 3724, 5739, 7413, 5468, 4673, 1920, 4868, 2372, 8043, 102, 5739, 7413, 5468, 4673, 8024, 3724, 1920, 4868, 2372, 172, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO | __main__:<module>:76 - label: [1]
2020-10-14 22:28:57.422 | INFO | __main__:<module>:73 - Sample 9:
2020-10-14 22:28:57.422 | INFO | __main__:<module>:74 - word2id: [101, 1963, 1217, 677, 784, 720, 6956, 7674, 102, 5314, 691, 1217, 677, 6956, 7674, 3221, 784, 720, 2099, 8043, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.422 | INFO | __main__:<module>:75 - segment_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.422 | INFO | __main__:<module>:76 - label: [0]
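The point of all this is batch prediction: iterate the generator and feed each batch to the exported model in a single forward pass. A minimal sketch; `predict_fn` below is a dummy stand-in for the real model call (e.g. running the loaded pb graph), and `fake_batches` mimics the `[x, y]` layout the generator yields, so the snippet is self-contained:

```python
import numpy as np

def predict_fn(token_ids, segment_ids):
    """Stand-in for the real forward pass (e.g. feeding the loaded pb graph)."""
    return np.zeros((len(token_ids), 2))  # fake logits of shape (batch_size, num_classes)

# Fake batches in the same ([token_ids, segment_ids], labels) layout the
# generator yields; in real use, iterate the data_generator instance instead.
fake_batches = [
    ([np.ones((4, 10)), np.zeros((4, 10))], np.zeros((4, 1))),
    ([np.ones((2, 10)), np.zeros((2, 10))], np.zeros((2, 1))),
]

all_logits = []
for (token_ids, segment_ids), labels in fake_batches:
    all_logits.append(predict_fn(token_ids, segment_ids))  # one forward pass per batch

logits = np.concatenate(all_logits, axis=0)  # (N, num_classes) for all N samples
```

This is where the speedup comes from: N samples cost only N/batch_size model calls instead of N.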
Notes:
- "vocab.txt" in the code is the standard vocabulary file shipped with the pretrained BERT checkpoint
- "test.txt" contains the first 10 lines of the LCQMC test set, shown below (for demo purposes only):
谁有狂三这张高清的 这张高清图,谁有 0
英雄联盟什么英雄最好 英雄联盟最好英雄是什么 1
这是什么意思,被蹭网吗 我也是醉了,这是什么意思 0
现在有什么动画片好看呢? 现在有什么好看的动画片吗? 1
请问晶达电子厂现在的工资待遇怎么样要求有哪些 三星电子厂工资待遇怎么样啊 0
文章真的爱姚笛吗 姚笛真的被文章干了吗 0
送自己做的闺蜜什么生日礼物好 送闺蜜什么生日礼物好 1
近期上映的电影 近期上映的电影有哪些 1
求英雄联盟大神带? 英雄联盟,求大神带~ 1
如加上什么部首 给东加上部首是什么字? 0
If your task is classification rather than sentence-pair matching (line format: text\tlabel), only two lines in the data_generator class need to change:

```python
for is_end, (text1, label) in self.sample(random):
    token_ids, segment_ids = self.tokenizer.encode(text1, maxlen=self.maxlen)
```
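For readers who want to see the whole modified loop at once, here is a framework-free sketch of the classification-format batching; `tokenize` below is a stub standing in for `Tokenizer.encode`, purely for illustration:

```python
def tokenize(text, maxlen=128):
    """Stub for Tokenizer.encode: fake (token_ids, segment_ids) for one text."""
    token_ids = [101] + [ord(ch) % 1000 + 1 for ch in text][:maxlen - 2] + [102]
    segment_ids = [0] * len(token_ids)  # single-sentence input: all zeros
    return token_ids, segment_ids

def iter_batches(data, batch_size):
    """Classification-format batching: data is a list of (text, label) pairs."""
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for i, (text, label) in enumerate(data):
        is_end = (i == len(data) - 1)
        token_ids, segment_ids = tokenize(text)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])
        if len(batch_token_ids) == batch_size or is_end:
            yield [batch_token_ids, batch_segment_ids], batch_labels
            batch_token_ids, batch_segment_ids, batch_labels = [], [], []

data = [("好看", 1), ("无聊", 0), ("一般", 0)]
for x, y in iter_batches(data, batch_size=2):
    print(len(x[0]), y)
```

In the real class, `self.sample(random)` supplies the `(is_end, sample)` pairs and `sequence_padding` pads each batch before yielding; the control flow is otherwise identical.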
Other tasks can be handled the same way. That's a wrap 🎉🎉🎉. To close, a recommendation: the well-reviewed animated series 《灵笼》.