【roformer_sim】MASK模型发散生成
MASK模型发散生成
苏神在roformer_sim文章最后讲了个小玩具,“科学是第一生产力”mask掉“第一”,模型可发散生成“科学是第二生产力”,正好做文成生成需要,源码中没有,求助大佬后实现功能,现将做法写于此,旨在帮助不想看源码的同学们快速实现功能。
苏神原文链接:https://kexue.fm/archives/8454
源码地址:https://github.com/ZhuiyiTechnology/roformer-sim
原代码
clone 下来后,下载模型文件,解压到路径,在test中的 generate.py 文件中更改PRETRAINED变量为指定路径
上述操作完成后,就能run起来了~
让我们来看下代码:
class SynonymsGenerator(AutoRegressiveDecoder):
"""seq2seq解码器
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, steps):
token_ids, segment_ids = inputs
token_ids = np.concatenate([token_ids, output_ids], 1)
segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
return self.last_token(seq2seq).predict([token_ids, segment_ids])
def generate(self, text, n=1, topp=0.95, mask_id=None):
token_ids, segment_ids = tokenizer.encode(text, maxlen=MAXLEN)
output_ids = self.random_sample([token_ids, segment_ids], n,
topp=topp) # 基于随机采样
return [tokenizer.decode(ids) for ids in output_ids]
不难发现,其需要的编码有两个,一个为token_ids 对应每个字的id,一个为segment_ids,对应为句子id,长度为句子字数+[cls]+[sep]=句子字数+2,使用阶段为单句生成,那么segment就都为0。
此外原代码使用的是tokenizer中的函数encoder生成的这两种id映射。因此我们只需要更改token_ids中的id为MASK的id就能实现“科学技术是[MASK][MASK]生产力的效果”。
更改代码
先看下更改后的代码,为了保持原有功能,我们添加一个masked_encode函数(这里需注意,为个人写的函数,并不是stage中的masked_encode函数)。
def masked_encode(text, mask_id=None):
token_ids, segment_ids = tokenizer.encode(text, maxlen=MAXLEN)
# print('有mask')
for i in mask_id:
token_ids[i] = tokenizer._token_mask_id
return token_ids, segment_ids
从代码中可看出,tokenizer._token_mask_id即为[MASK]对应的id。再对类函数进行更改。
class SynonymsGenerator(AutoRegressiveDecoder):
"""seq2seq解码器
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, steps):
token_ids, segment_ids = inputs
token_ids = np.concatenate([token_ids, output_ids], 1)
segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
return self.last_token(seq2seq).predict([token_ids, segment_ids])
def generate(self, text, n=1, topp=0.95, mask_id=None):
# 自改, 把tokenizer._token_mask加入进去
if mask_id:
token_ids, segment_ids = masked_encode(text, mask_id)
else:
token_ids, segment_ids = tokenizer.encode(text, maxlen=MAXLEN)
output_ids = self.random_sample([token_ids, segment_ids], n,
topp=topp) # 基于随机采样
return [tokenizer.decode(ids) for ids in output_ids]
最后输出
上述更改完成后,进行测试
if __name__=='__main__':
# input_text = input('>> ')
#t1 = datetime.now()
mask_id = [6, 7]
input_text = '科学技术是第一生产力'
print(input_text)
mask_text = list(input_text)
for i in mask_id:
mask_text[i-1] = '[MASK]'
mask_text = ''.join(mask_text)
print(mask_text)
print(gen_synonyms(input_text, n=50, k=20, mask_id=mask_id))
# print('time cost:', datetime.now() - t1)
结果为:
最后非常非常感谢大佬指点(鞠躬感谢),后续随缘更新部分bert4keras部分模型微调或是代码解释,这个包确实比原生上手简单,可惜苏神文档开了个头就断更了~~。
PS:所有版本配置见源码或是苏神文章(都在开头)