bert以首字表示词向量(2)

第二篇文章,通过一种新的方式来实现以首字表示词向量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

# -*- coding: utf8 -*-
#
from typing import List
from unittest import TestCase

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizerBase


def tokenize(form: List[List[str]], tokenizer: PreTrainedTokenizerBase, max_length: int, char_base: bool = False):
"""

Args:
form:
tokenizer:
max_length:
char_base: 这里指的是form[即 word]是否是字级别的

Returns:

"""
res = tokenizer.batch_encode_plus(
form,
is_split_into_words=True,
max_length=max_length,
truncation=True,
)
result = res.data
# 可用于长度大于指定长度过滤, overflow指字长度大于指定max_length,如果有cls,sep,那么就算上这个
result['overflow'] = [len(encoding.overflowing) > 0 for encoding in res.encodings]
if not char_base:
word_index = []
for encoding in res.encodings:
word_index.append([])

last_word_idx = -1
current_length = 0
for word_idx in encoding.word_ids[1:-1]:
if word_idx != last_word_idx:
word_index[-1].append(current_length)

current_length += 1
last_word_idx = word_idx
result['word_index'] = word_index
result['word_attention_mask'] = [[True] * len(index) for index in word_index]
return result


class TestSample(TestCase):
def test_max_length(self):
"""
测试max_length overflow情况
:return:
"""
pass

def test_sample(self):
form = [
['我', '呀'],
['我', '小明', '呀']
]

tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-electra-180g-small-discriminator')
result = tokenize(form, tokenizer, 6)
model = AutoModel.from_pretrained('hfl/chinese-electra-180g-small-discriminator')

input_ids = pad_sequence([torch.tensor(input_ids) for input_ids in result['input_ids']], batch_first=True)
token_type_ids = pad_sequence([torch.tensor(token_type_ids) for token_type_ids in result['token_type_ids']],
batch_first=True)
attention_mask = pad_sequence([torch.tensor(attention_mask) for attention_mask in result['attention_mask']],
batch_first=True)

# tensor([[ 101, 2769, 1435, 102, 0, 0],
# [ 101, 2769, 2207, 3209, 1435, 102]])

# 1. 获取bert output.
bert_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
seq_out = bert_out[0]

word_index = pad_sequence([torch.tensor(word_index) for word_index in result['word_index']], batch_first=True)

# 2. 获取词首字向量,包括cls开头
word_out = torch.cat([seq_out[:, :1, :], torch.gather(
seq_out[:, 1:, :], dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1))
)], dim=1)

word_attention_mask = pad_sequence(
[torch.tensor(word_attention_mask) for word_attention_mask in result['word_attention_mask']],
batch_first=True)

# 这里方便view
# 1. ['我', '呀']
self.assertTrue((seq_out[0][0] == word_out[0][0]).all()) # cls
self.assertTrue((seq_out[0][1] == word_out[0][1]).all()) # 我
self.assertTrue((seq_out[0][2] == word_out[0][2]).all()) # 呀

self.assertTrue((word_out[0][1] == word_out[0][3]).all()) # 填充位

# 2. ['我', '小明', '呀']
self.assertTrue((seq_out[1][0] == word_out[1][0]).all()) # cls
self.assertTrue((seq_out[1][1] == word_out[1][1]).all()) # 我
self.assertTrue((seq_out[1][2] == word_out[1][2]).all()) # 小明
self.assertTrue((seq_out[1][4] == word_out[1][3]).all()) # 呀

# 3. Note: word_out的时候concat了seq_out[:, :1, :](cls),所以word_out的长度比word_attention_mask大1
self.assertEqual(word_out.size(1), word_attention_mask.size(1) + 1)

# 4. 获取每个词对应的向量
result = word_out[:, 1:, :][word_attention_mask]
result2 = result.split(word_attention_mask.sum(1).tolist())
self.assertEqual(len(result2[0]), 2)
self.assertEqual(len(result2[1]), 3)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值