from typing import List
from unittest import TestCase

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizerBase


def tokenize(form: List[List[str]], tokenizer: PreTrainedTokenizerBase, max_length: int, char_base: bool = False):
    """
    Args:
        form: batch of pre-tokenized sentences, one list of words per sentence.
        tokenizer: a fast (rust-backed) tokenizer, required for `encodings` / `word_ids`.
        max_length: truncation length in subwords, including [CLS] and [SEP].
        char_base: whether `form` (i.e. the "words") is already character-level,
            in which case no word-to-subword alignment is built.

    Returns:
        The encoding dict plus `overflow`, and `word_index` / `word_attention_mask`
        when `char_base` is False.
    """
    res = tokenizer.batch_encode_plus(
        form,
        is_split_into_words=True,
        max_length=max_length,
        truncation=True,
    )
    result = res.data
    result['overflow'] = [len(encoding.overflowing) > 0 for encoding in res.encodings]
    if not char_base:
        word_index = []
        for encoding in res.encodings:
            word_index.append([])

            last_word_idx = -1
            current_length = 0
            # Walk the subwords between [CLS] and [SEP], recording the offset of the
            # first subword of every word.
            for word_idx in encoding.word_ids[1:-1]:
                if word_idx != last_word_idx:
                    word_index[-1].append(current_length)

                current_length += 1
                last_word_idx = word_idx
        result['word_index'] = word_index
        result['word_attention_mask'] = [[True] * len(index) for index in word_index]
    return result
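

# A minimal sketch of how `word_index` is meant to be consumed, using dummy tensors:
# the helper name and shapes are illustrative assumptions, only the gather pattern
# mirrors TestSample.test_sample below (pick the first-subword vector of every word).
def _word_pooling_sketch():
    seq_out = torch.arange(2 * 6 * 4, dtype=torch.float).view(2, 6, 4)  # [batch, subword_len, hidden]
    word_index = torch.tensor([[0, 1, 0],   # ['我', '呀'] -> offsets 0, 1 (0-padded)
                               [0, 1, 3]])  # ['我', '小明', '呀'] -> '小明' spans two subwords
    # Drop the [CLS] column, then gather each word's first-subword vector along dim=1.
    word_out = torch.gather(
        seq_out[:, 1:, :], dim=1,
        index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1)))
    return word_out  # [batch, word_len, hidden]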


class TestSample(TestCase):
    def test_max_length(self):
        """
        Test the max_length / overflow case.
        :return:
        """
        pass

    def test_sample(self):
        form = [
            ['我', '呀'],
            ['我', '小明', '呀']
        ]
        tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-electra-180g-small-discriminator')
        result = tokenize(form, tokenizer, 6)
        model = AutoModel.from_pretrained('hfl/chinese-electra-180g-small-discriminator')

        input_ids = pad_sequence([torch.tensor(input_ids) for input_ids in result['input_ids']], batch_first=True)
        token_type_ids = pad_sequence([torch.tensor(token_type_ids) for token_type_ids in result['token_type_ids']],
                                      batch_first=True)
        attention_mask = pad_sequence([torch.tensor(attention_mask) for attention_mask in result['attention_mask']],
                                      batch_first=True)

        bert_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        seq_out = bert_out[0]  # [batch, subword_len, hidden]

        word_index = pad_sequence([torch.tensor(word_index) for word_index in result['word_index']], batch_first=True)

        # Keep the [CLS] vector, then gather the first-subword vector of every word.
        word_out = torch.cat([seq_out[:, :1, :], torch.gather(
            seq_out[:, 1:, :], dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1))
        )], dim=1)

        word_attention_mask = pad_sequence(
            [torch.tensor(word_attention_mask) for word_attention_mask in result['word_attention_mask']],
            batch_first=True)

        # First sentence: every word is a single character, so word_out mirrors seq_out.
        self.assertTrue((seq_out[0][0] == word_out[0][0]).all())
        self.assertTrue((seq_out[0][1] == word_out[0][1]).all())
        self.assertTrue((seq_out[0][2] == word_out[0][2]).all())

        # word_index is padded with 0, so the padded slot repeats the first word's vector.
        self.assertTrue((word_out[0][1] == word_out[0][3]).all())

        # Second sentence: '小明' spans two subwords, so its word vector is the '小' vector,
        # and '呀' (subword position 4) lands at word position 3.
        self.assertTrue((seq_out[1][0] == word_out[1][0]).all())
        self.assertTrue((seq_out[1][1] == word_out[1][1]).all())
        self.assertTrue((seq_out[1][2] == word_out[1][2]).all())
        self.assertTrue((seq_out[1][4] == word_out[1][3]).all())

        # word_out keeps an extra [CLS] column in front of the word positions.
        self.assertEqual(word_out.size(1), word_attention_mask.size(1) + 1)

        # Masking then splitting recovers the per-sentence word counts: 2 and 3.
        result = word_out[:, 1:, :][word_attention_mask]
        result2 = result.split(word_attention_mask.sum(1).tolist())
        self.assertEqual(len(result2[0]), 2)
        self.assertEqual(len(result2[1]), 3)