利用matchzoo-py库实现房地产问答匹配问题
导入库包
# import matchzoo as mz
import pandas as pd
import numpy as np
import numpy as np
# import tensorflow.keras as K
# from matchzoo.preprocessors import BasicPreprocessor
from sklearn.model_selection import train_test_split
import torch
import matchzoo as mz
# import datetime
C:\Users\Administrator\Anaconda3\lib\requests\__init__.py:80: RequestsDependencyWarning: urllib3 (1.25.11) or chardet (3.0.4) doesn't match a supported version!
RequestsDependencyWarning)
数据处理
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
# import tensorflow as tf
# from keras.layers import *
# from keras.models import Model
# import keras.backend as K
# from keras.optimizers import Adam
from random import choice
# from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import re, os
import codecs
# from keras.callbacks import Callback
# Read the raw data and build one row per (query, reply) pair.
# The TSV files have no header row, so column names are assigned by hand.
train_left = pd.read_csv('./train/train.query.tsv', sep='\t', header=None)
train_left.columns = ['id', 'q1']
train_right = pd.read_csv('./train/train.reply.tsv', sep='\t', header=None)
train_right.columns = ['id', 'id_sub', 'q2', 'label']
# Join on the query id explicitly rather than relying on pandas picking
# "every column the two frames have in common" as the key.
df_train = train_left.merge(train_right, how='left', on='id')
# Missing reply text (NaN) is replaced with the placeholder reply '好的'.
df_train['q2'] = df_train['q2'].fillna('好的')

# NOTE(review): the test files are decoded as GBK while the training files
# use the pandas default encoding -- confirm this matches the files on disk.
test_left = pd.read_csv('./test/test.query.tsv', sep='\t', header=None, encoding='gbk')
test_left.columns = ['id', 'q1']
test_right = pd.read_csv('./test/test.reply.tsv', sep='\t', header=None, encoding='gbk')
test_right.columns = ['id', 'id_sub', 'q2']
df_test = test_left.merge(test_right, how='left', on='id')
# Treat the test replies the same way as the training replies so that a
# missing reply never reaches the word segmenter as a float NaN.
df_test['q2'] = df_test['q2'].fillna('好的')

# Only the first 100 training pairs are used in this experiment.
sent1_ = df_train.q1.values[:100]
sent2_ = df_train.q2.values[:100]
label_ = df_train.label.values[:100]
from tqdm import tqdm
from fastHan import FastHan
# Load the FastHan 'base' model; change() calls it with the 'CWS' target to
# split sentences into words.
# NOTE(review): the name `model` is rebound further down to the ArcI network,
# so anything that resolves `model` at call time after that point no longer
# gets this segmenter.
model=FastHan(model_type='base')
loading vocabulary file C:\Users\Administrator\.fastNLP\fasthan\fasthan_base\vocab.txt
Load pre-trained BERT parameters from file C:\Users\Administrator\.fastNLP\fasthan\fasthan_base\model.bin.
def change(tmp, segmenter=model):
    """Segment every sentence in *tmp* and return space-joined token strings.

    Args:
        tmp: Iterable of raw sentences (strings).
        segmenter: Callable invoked as ``segmenter(sentence, 'CWS')``. It
            defaults to the FastHan instance that is bound to ``model`` at
            the moment this function is defined, so a later rebinding of the
            module-level name ``model`` does not affect this function.

    Returns:
        A list with one string per input sentence: the tokens of the first
        element of the segmenter's result, joined by single spaces.
    """
    # tqdm already reports progress; the former per-item counter print only
    # interleaved numbers with the progress bar.
    return [' '.join(segmenter(sentence, 'CWS')[0]) for sentence in tqdm(tmp)]
# Word-segment the sampled training queries (the q1 column).
sent1_list=change(sent1_)
2%|█▋ | 2/100 [00:00<00:06, 14.40it/s]
0
1
2
3
6%|████▉ | 6/100 [00:00<00:05, 15.99it/s]
4
5
6
7
8
11%|████████▉ | 11/100 [00:00<00:05, 15.91it/s]
9
10
11
13%|██████████▌ | 13/100 [00:00<00:06, 13.94it/s]
12
13
14
17%|█████████████▊ | 17/100 [00:01<00:05, 14.74it/s]
15
16
17
18
21%|█████████████████ | 21/100 [00:01<00:05, 15.69it/s]
19
20
21
22
23
26%|█████████████████████ | 26/100 [00:01<00:04, 16.57it/s]
24
25
26
27
30%|████████████████████████▎ | 30/100 [00:01<00:04, 16.21it/s]
28
29
30
31
35%|████████████████████████████▎ | 35/100 [00:02<00:03, 17.70it/s]
32
33
34
35
39%|███████████████████████████████▌ | 39/100 [00:02<00:03, 18.36it/s]
36
37
38
39
43%|██████████████████████████████████▊ | 43/100 [00:02<00:03, 18.09it/s]
40
41
42
43
45%|████████████████████████████████████▍ | 45/100 [00:02<00:03, 17.19it/s]
44
45
46
47
50%|████████████████████████████████████████▌ | 50/100 [00:02<00:02, 16.78it/s]
48
49
50
51
54%|███████████████████████████████████████████▋ | 54/100 [00:03<00:02, 17.57it/s]
52
53
54
55
58%|██████████████████████████████████████████████▉ | 58/100 [00:03<00:02, 16.50it/s]
56
57
58
59
62%|██████████████████████████████████████████████████▏ | 62/100 [00:03<00:02, 16.90it/s]
60
61
62
63
68%|███████████████████████████████████████████████████████ | 68/100 [00:04<00:01, 18.25it/s]
64
65
66
67
70%|████████████████████████████████████████████████████████▋ | 70/100 [00:04<00:01, 18.28it/s]
68
69
70
71
72
73%|███████████████████████████████████████████████████████████▏ | 73/100 [00:04<00:01, 18.84it/s]
73
74
75%|████████████████████████████████████████████████████████████▊ | 75/100 [00:04<00:02, 12.50it/s]
75
77%|██████████████████████████████████████████████████████████████▎ | 77/100 [00:04<00:02, 10.99it/s]
76
77
78
81%|█████████████████████████████████████████████████████████████████▌ | 81/100 [00:05<00:01, 13.74it/s]
79
80
81
82
83
86%|█████████████████████████████████████████████████████████████████████▋ | 86/100 [00:05<00:00, 16.09it/s]
84
85
86
87
90%|████████████████████████████████████████████████████████████████████████▉ | 90/100 [00:05<00:00, 16.52it/s]
88
89
90
91
94%|████████████████████████████████████████████████████████████████████████████▏ | 94/100 [00:05<00:00, 15.81it/s]
92
93
94
95
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.51it/s]
96
97
98
99
# Word-segment the sampled training replies (the q2 column).
sent2_list=change(sent2_)
2%|█▋ | 2/100 [00:00<00:06, 14.50it/s]
0
1
2
3
6%|████▉ | 6/100 [00:00<00:05, 16.34it/s]
4
5
6
7
10%|████████ | 10/100 [00:00<00:05, 16.40it/s]
8
9
10
12%|█████████▋ | 12/100 [00:00<00:05, 15.63it/s]
11
12
13
16%|████████████▉ | 16/100 [00:01<00:05, 14.34it/s]
14
15
16
17
20%|████████████████▏ | 20/100 [00:01<00:05, 13.89it/s]
18
19
20
21
24%|███████████████████▍ | 24/100 [00:01<00:04, 15.53it/s]
22
23
24
25
26
29%|███████████████████████▍ | 29/100 [00:01<00:04, 16.54it/s]
27
28
29
30
33%|██████████████████████████▋ | 33/100 [00:02<00:04, 16.57it/s]
31
32
33
34
37%|█████████████████████████████▉ | 37/100 [00:02<00:03, 17.75it/s]
35
36
37
38
41%|█████████████████████████████████▏ | 41/100 [00:02<00:03, 17.85it/s]
39
40
41
42
45%|████████████████████████████████████▍ | 45/100 [00:02<00:03, 16.02it/s]
43
44
45
46
49%|███████████████████████████████████████▋ | 49/100 [00:03<00:03, 16.45it/s]
47
48
49
50
53%|██████████████████████████████████████████▉ | 53/100 [00:03<00:02, 16.53it/s]
51
52
53
54
57%|██████████████████████████████████████████████▏ | 57/100 [00:03<00:02, 16.71it/s]
55
56
57
58
59%|███████████████████████████████████████████████▊ | 59/100 [00:03<00:02, 17.22it/s]
59
60
61%|█████████████████████████████████████████████████▍ | 61/100 [00:03<00:02, 13.05it/s]
61
62
63%|███████████████████████████████████████████████████ | 63/100 [00:04<00:03, 9.87it/s]
63
64
67%|██████████████████████████████████████████████████████▎ | 67/100 [00:04<00:02, 12.57it/s]
65
66
67
68
72%|██████████████████████████████████████████████████████████▎ | 72/100 [00:04<00:01, 15.77it/s]
69
70
71
72
73
77%|██████████████████████████████████████████████████████████████▎ | 77/100 [00:04<00:01, 17.58it/s]
74
75
76
77
78
83%|███████████████████████████████████████████████████████████████████▏ | 83/100 [00:05<00:00, 19.33it/s]
79
80
81
82
83
86%|█████████████████████████████████████████████████████████████████████▋ | 86/100 [00:05<00:00, 18.83it/s]
84
85
86
87
90%|████████████████████████████████████████████████████████████████████████▉ | 90/100 [00:05<00:00, 18.70it/s]
88
89
90
91
95%|████████████████████████████████████████████████████████████████████████████▉ | 95/100 [00:05<00:00, 19.67it/s]
92
93
94
95
96
97
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.38it/s]
98
99
# Training frame in the column layout consumed by mz.pack: segmented texts,
# a running integer id for each side, and the label.
all_data = pd.DataFrame({
    'text_left': sent1_list,
    'text_right': sent2_list,
    'id_left': range(len(sent1_list)),
    'id_right': range(len(sent2_list)),
    'label': label_,
})

# First 100 test pairs; no label is attached to them here.
_sent1 = df_test['q1'].values[:100]
_sent2 = df_test['q2'].values[:100]
_sent1list = change(_sent1)
_sent2list = change(_sent2)
0%| | 0/100 [00:00<?, ?it/s]
0
2%|█▋ | 2/100 [00:00<00:06, 14.09it/s]
1
2
6%|████▉ | 6/100 [00:00<00:06, 15.32it/s]
3
4
5
6
11%|████████▉ | 11/100 [00:00<00:05, 17.48it/s]
7
8
9
10
11
15%|████████████▏ | 15/100 [00:00<00:04, 17.45it/s]
12
13
14
15
19%|███████████████▍ | 19/100 [00:01<00:04, 17.37it/s]
16
17
18
19
23%|██████████████████▋ | 23/100 [00:01<00:04, 16.29it/s]
20
21
22
23
27%|█████████████████████▊ | 27/100 [00:01<00:04, 16.96it/s]
24
25
26
27
31%|█████████████████████████ | 31/100 [00:01<00:03, 17.35it/s]
28
29
30
31
33%|██████████████████████████▋ | 33/100 [00:01<00:04, 16.34it/s]
32
33
34
37%|█████████████████████████████▉ | 37/100 [00:02<00:03, 15.93it/s]
35
36
37
38
41%|█████████████████████████████████▏ | 41/100 [00:02<00:03, 16.81it/s]
39
40
41
42
45%|████████████████████████████████████▍ | 45/100 [00:02<00:03, 17.04it/s]
43
44
45
46
49%|███████████████████████████████████████▋ | 49/100 [00:02<00:02, 17.11it/s]
47
48
49
50
51%|█████████████████████████████████████████▎ | 51/100 [00:03<00:02, 17.79it/s]
51
52
53%|██████████████████████████████████████████▉ | 53/100 [00:03<00:03, 12.23it/s]
53
54
57%|██████████████████████████████████████████████▏ | 57/100 [00:03<00:03, 11.09it/s]
55
56
57
61%|█████████████████████████████████████████████████▍ | 61/100 [00:03<00:03, 12.72it/s]
58
59
60
61
65%|████████████████████████████████████████████████████▋ | 65/100 [00:04<00:02, 13.78it/s]
62
63
64
65
69%|███████████████████████████████████████████████████████▉ | 69/100 [00:04<00:02, 14.98it/s]
66
67
68
69
73%|███████████████████████████████████████████████████████████▏ | 73/100 [00:04<00:01, 16.94it/s]
70
71
72
73
74
77%|██████████████████████████████████████████████████████████████▎ | 77/100 [00:04<00:01, 17.28it/s]
75
76
77
78
81%|█████████████████████████████████████████████████████████████████▌ | 81/100 [00:05<00:01, 17.84it/s]
79
80
81
82
83
86%|█████████████████████████████████████████████████████████████████████▋ | 86/100 [00:05<00:00, 18.67it/s]
84
85
86
87
90%|████████████████████████████████████████████████████████████████████████▉ | 90/100 [00:05<00:00, 17.55it/s]
88
89
90
91
94%|████████████████████████████████████████████████████████████████████████████▏ | 94/100 [00:05<00:00, 13.44it/s]
92
93
94
98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [00:06<00:00, 14.87it/s]
95
96
97
98
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.73it/s]
2%|█▋ | 2/100 [00:00<00:05, 19.06it/s]
99
0
1
2
6%|████▉ | 6/100 [00:00<00:05, 17.73it/s]
3
4
5
6
10%|████████ | 10/100 [00:00<00:04, 18.02it/s]
7
8
9
10
14%|███████████▎ | 14/100 [00:00<00:04, 17.49it/s]
11
12
13
14
19%|███████████████▍ | 19/100 [00:01<00:04, 17.72it/s]
15
16
17
18
19
22%|█████████████████▊ | 22/100 [00:01<00:04, 17.55it/s]
20
21
22
23
26%|█████████████████████ | 26/100 [00:01<00:04, 17.85it/s]
24
25
26
27
30%|████████████████████████▎ | 30/100 [00:01<00:03, 18.53it/s]
28
29
30
31
32
36%|█████████████████████████████▏ | 36/100 [00:01<00:03, 20.68it/s]
33
34
35
36
37
42%|██████████████████████████████████ | 42/100 [00:02<00:02, 21.40it/s]
38
39
40
41
42
43
45%|████████████████████████████████████▍ | 45/100 [00:02<00:04, 11.67it/s]
44
45
47%|██████████████████████████████████████ | 47/100 [00:02<00:04, 11.74it/s]
46
47
48
51%|█████████████████████████████████████████▎ | 51/100 [00:03<00:03, 12.87it/s]
49
50
51
52
55%|████████████████████████████████████████████▌ | 55/100 [00:03<00:03, 14.34it/s]
53
54
55
56
59%|███████████████████████████████████████████████▊ | 59/100 [00:03<00:02, 15.54it/s]
57
58
59
60
63%|███████████████████████████████████████████████████ | 63/100 [00:03<00:02, 14.36it/s]
61
62
63
64
65%|████████████████████████████████████████████████████▋ | 65/100 [00:04<00:02, 13.50it/s]
65
66
69%|███████████████████████████████████████████████████████▉ | 69/100 [00:04<00:02, 14.14it/s]
67
68
69
70
72%|██████████████████████████████████████████████████████████▎ | 72/100 [00:04<00:01, 15.60it/s]
71
72
73
74
77%|██████████████████████████████████████████████████████████████▎ | 77/100 [00:04<00:01, 15.36it/s]
75
76
77
79%|███████████████████████████████████████████████████████████████▉ | 79/100 [00:05<00:01, 15.03it/s]
78
79
80
83%|███████████████████████████████████████████████████████████████████▏ | 83/100 [00:05<00:01, 14.02it/s]
81
82
83
88%|███████████████████████████████████████████████████████████████████████▎ | 88/100 [00:05<00:00, 15.69it/s]
84
85
86
87
88
90%|████████████████████████████████████████████████████████████████████████▉ | 90/100 [00:05<00:00, 14.71it/s]
89
90
91
92
96%|█████████████████████████████████████████████████████████████████████████████▊ | 96/100 [00:06<00:00, 15.95it/s]
93
94
95
96
98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [00:06<00:00, 12.88it/s]
97
98
99
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.50it/s]
# Split boundaries are computed on the sampled rows (sent1_), not on the full
# df_train: train_split marks the 70% position and dev_split the 90% position.
train_split = int(0.7 * len(sent1_))
dev_split = int(0.9 * len(sent1_))

# Test frame with the same text/id columns as the training frame, but no label.
tmp_data = pd.DataFrame({
    'text_left': _sent1list,
    'text_right': _sent2list,
    'id_left': range(len(_sent1list)),
    'id_right': range(len(_sent2list)),
})

# Two-class classification task, evaluated with accuracy.
cls_task = mz.tasks.Classification(num_classes=2)
cls_task.metrics = [mz.metrics.Accuracy]
def load_data(df_data):
    """Pack *df_data* into a MatchZoo data pack for the module-level ``cls_task``.

    Args:
        df_data: DataFrame handed straight to ``mz.pack``.

    Returns:
        Whatever ``mz.pack(df_data, task=cls_task)`` returns.
    """
    return mz.pack(df_data, task=cls_task)
# Pack the training frame and the (label-free) test frame via load_data().
train_data = load_data(all_data)
test_data=load_data(tmp_data)
# Notebook echo: show the train/dev boundary index.
train_split
70
# Sequential split without shuffling: rows before train_split are used for
# training, every remaining row for validation.
# NOTE(review): dev_split is not used by this slice, so no separate held-out
# part is carved off here -- confirm that is intended.
train, dev = train_data[:train_split], train_data[train_split:]
# Notebook echo: show the train/dev boundary index.
train_split
70
# Use ArcI's default preprocessor. It is fitted on the training split only and
# then applied unchanged to the validation split, so the validation data does
# not take part in fitting.
preprocessor = mz.models.ArcI.get_default_preprocessor()
train_processed = preprocessor.fit_transform(train)
valid_processed = preprocessor.transform(dev)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1459.16it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1522.59it/s]
Processing text_right with append: 100%|████████████████████████████████████████████| 70/70 [00:00<00:00, 35027.59it/s]
Building FrequencyFilter from a datapack.: 100%|████████████████████████████████████| 70/70 [00:00<00:00, 17510.66it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 35031.77it/s]
Processing text_left with extend: 100%|█████████████████████████████████████████████| 70/70 [00:00<00:00, 23349.87it/s]
Processing text_right with extend: 100%|████████████████████████████████████████████| 70/70 [00:00<00:00, 35077.81it/s]
Building Vocabulary from a datapack.: 100%|██████████████████████████████████████| 812/812 [00:00<00:00, 812252.53it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 2122.50it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1591.75it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 17503.36it/s]
Processing text_left with transform: 100%|██████████████████████████████████████████| 70/70 [00:00<00:00, 11674.94it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 11661.95it/s]
Processing length_left with len: 100%|██████████████████████████████████████████████| 70/70 [00:00<00:00, 35015.06it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 70/70 [00:00<00:00, 35035.95it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 30/30 [00:00<00:00, 1579.85it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 30/30 [00:00<00:00, 2728.30it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 30/30 [00:00<00:00, 10002.31it/s]
Processing text_left with transform: 100%|███████████████████████████████████████████| 30/30 [00:00<00:00, 2501.87it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 30/30 [00:00<00:00, 15036.94it/s]
Processing length_left with len: 100%|███████████████████████████████████████████████| 30/30 [00:00<00:00, 6001.29it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 30/30 [00:00<00:00, 29987.87it/s]
# Notebook inspection: display the unpacked validation features and labels.
valid_processed.unpack()
({'id_left': array([70, 71, 72, 73, 74, 77, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 90,
91, 92, 93, 95, 97, 98, 99]),
'id_right': array([70, 71, 72, 73, 74, 77, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 90,
91, 92, 93, 95, 97, 98, 99]),
'length_left': array([ 4, 4, 7, 7, 7, 6, 5, 5, 5, 5, 5, 5, 5, 8, 8, 10, 10,
10, 10, 10, 2, 4, 4, 4]),
'length_right': array([3, 2, 2, 3, 1, 2, 1, 1, 7, 2, 5, 1, 4, 5, 1, 4, 6, 3, 3, 1, 2, 7,
1, 3]),
'text_left': array([list([144, 197, 15, 151]), list([144, 197, 15, 151]),
list([238, 30, 151, 49, 220, 228, 86]),
list([238, 30, 151, 49, 220, 228, 86]),
list([238, 30, 151, 49, 220, 228, 86]),
list([146, 144, 198, 1, 195, 151]), list([1, 236, 220, 228, 86]),
list([1, 236, 220, 228, 86]), list([1, 236, 220, 228, 86]),
list([238, 166, 14, 172, 86]), list([238, 166, 14, 172, 86]),
list([238, 166, 14, 172, 86]), list([238, 166, 14, 172, 86]),
list([142, 117, 146, 144, 197, 15, 238, 151]),
list([142, 117, 146, 144, 197, 15, 238, 151]),
list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]), list([1, 86]),
list([61, 197, 151, 99]), list([61, 197, 151, 99]),
list([61, 197, 151, 99])], dtype=object),
'text_right': array([list([142, 45, 197]), list([117, 195]), list([142, 117]),
list([49, 83, 228]), list([195]), list([83, 195]), list([83]),
list([195]), list([147, 170, 251, 195, 56, 170, 164]),
list([78, 197]), list([146, 22, 30, 75, 38]), list([110]),
list([195, 56, 188, 117]), list([146, 215, 142, 80, 30]),
list([38]), list([256, 12, 137, 142]),
list([8, 177, 70, 195, 7, 177]), list([244, 195, 122]),
list([202, 42, 136]), list([195]), list([217, 86]),
list([238, 115, 151, 166, 250, 195, 137]), list([142]),
list([221, 236, 166])], dtype=object)},
array([[1],
[0],
[0],
[1],
[0],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0]]))
# Notebook inspection: display the unpacked training features and labels.
train_processed.unpack()
({'id_left': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
69]),
'id_right': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
69]),
'length_left': array([ 6, 6, 6, 2, 2, 2, 2, 2, 10, 10, 10, 10, 10, 3, 3, 3, 9,
9, 9, 4, 4, 4, 4, 4, 9, 9, 9, 7, 7, 7, 5, 5, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 9, 9, 2, 2, 2, 8, 8, 8, 4,
4, 4, 9, 9, 9, 9, 2, 2, 2, 5, 5, 5, 3, 3, 3, 3, 3,
4]),
'length_right': array([11, 2, 4, 6, 2, 9, 7, 6, 8, 4, 9, 2, 10, 7, 5, 10, 1,
18, 27, 3, 5, 4, 8, 12, 4, 2, 16, 8, 1, 5, 8, 5, 2, 7,
3, 6, 7, 9, 1, 4, 10, 10, 7, 6, 2, 2, 5, 9, 12, 13, 9,
2, 2, 8, 5, 9, 7, 3, 5, 2, 8, 5, 5, 5, 4, 8, 11, 5,
2]),
'text_left': array([list([249, 14, 126, 166, 68, 87]),
list([249, 14, 126, 166, 68, 87]),
list([249, 14, 126, 166, 68, 87]), list([180, 86]),
list([180, 86]), list([180, 86]), list([180, 86]), list([180, 86]),
list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
list([263, 131, 99]), list([263, 131, 99]), list([263, 131, 99]),
list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
list([239, 206, 108, 252]), list([239, 206, 108, 252]),
list([239, 206, 108, 252]), list([239, 206, 108, 252]),
list([239, 206, 108, 252]),
list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
list([170, 151, 83, 133, 146, 198, 86]),
list([170, 151, 83, 133, 146, 198, 86]),
list([170, 151, 83, 133, 146, 198, 86]),
list([62, 27, 117, 166, 86]), list([62, 27, 117, 166, 86]),
list([67, 172, 195]), list([67, 172, 195]), list([67, 172, 195]),
list([67, 172, 195]), list([67, 172, 195]), list([45, 143, 88]),
list([45, 143, 88]), list([45, 143, 88]), list([45, 143, 88]),
list([45, 143, 88]),
list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
list([238, 115, 154, 49, 236, 170, 50, 208, 86]), list([234, 86]),
list([234, 86]), list([234, 86]),
list([165, 236, 138, 78, 197, 15, 238, 115]),
list([165, 236, 138, 78, 197, 15, 238, 115]),
list([165, 236, 138, 78, 197, 15, 238, 115]),
list([197, 150, 161, 86]), list([197, 150, 161, 86]),
list([197, 150, 161, 86]),
list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
list([238, 30, 127, 247, 95, 30, 128, 90, 31]), list([265, 108]),
list([265, 108]), list([265, 108]), list([49, 168, 55, 108, 75]),
list([49, 168, 55, 108, 75]), list([49, 168, 55, 108, 75]),
list([248, 141, 88]), list([248, 141, 88]), list([248, 141, 88]),
list([248, 141, 88]), list([248, 141, 88]),
list([144, 197, 15, 151])], dtype=object),
'text_right': array([list([173, 249, 210, 128, 254, 175, 173, 253, 159, 121, 119]),
list([166, 195]), list([238, 166, 7, 177]),
list([100, 65, 230, 229, 195, 129]), list([166, 88]),
list([238, 115, 14, 177, 133, 259, 195, 142, 198]),
list([152, 246, 166, 42, 136, 107, 195]),
list([117, 195, 142, 61, 197, 23]),
list([142, 166, 264, 115, 237, 40, 115, 88]),
list([155, 231, 24, 10]),
list([73, 199, 207, 90, 147, 195, 171, 90, 153]), list([117, 195]),
list([256, 112, 12, 194, 166, 13, 85, 139, 260, 218]),
list([178, 255, 230, 83, 221, 236, 195]),
list([170, 14, 188, 263, 131]),
list([81, 170, 245, 14, 115, 180, 195, 27, 263, 131]), list([170]),
list([182, 186, 183, 195, 160, 104, 5, 133, 225, 198, 78, 86, 50, 189, 137, 109, 43, 191]),
list([124, 250, 219, 74, 103, 7, 8, 84, 214, 163, 202, 235, 42, 136, 96, 4, 177, 76, 244, 59, 29, 200, 75, 197, 150, 170, 251]),
list([142, 198, 239]), list([3, 177, 211, 225, 195]),
list([142, 192, 161, 86]),
list([83, 61, 214, 22, 133, 142, 197, 15]),
list([142, 205, 15, 215, 142, 80, 257, 142, 188, 242, 130, 224]),
list([49, 89, 82, 228]), list([142, 117]),
list([142, 117, 142, 36, 91, 195, 154, 149, 137, 51, 142, 170, 164, 78, 198, 86]),
list([123, 222, 233, 170, 115, 11, 135, 195]), list([209]),
list([83, 133, 142, 197, 15]),
list([62, 262, 117, 79, 223, 246, 166, 83]),
list([56, 190, 101, 243, 198]), list([166, 195]),
list([238, 30, 166, 40, 172, 195, 216]), list([236, 170, 195]),
list([40, 172, 246, 69, 26, 47]),
list([14, 172, 195, 170, 63, 195, 38]),
list([147, 25, 179, 167, 192, 37, 220, 197, 150]), list([94]),
list([110, 166, 241, 195]),
list([60, 37, 52, 176, 145, 49, 204, 170, 17, 185]),
list([28, 238, 203, 156, 151, 185, 208, 246, 27, 109]),
list([238, 115, 151, 49, 236, 170, 208]),
list([66, 226, 197, 142, 48, 162]), list([125, 195]),
list([21, 195]), list([142, 144, 39, 15, 45]),
list([165, 78, 198, 92, 2, 116, 118, 37, 198]),
list([146, 134, 56, 174, 23, 98, 69, 71, 146, 197, 23, 212]),
list([146, 46, 232, 44, 197, 160, 38, 213, 38, 43, 113, 187, 38]),
list([238, 115, 151, 102, 6, 177, 166, 32, 177]), list([56, 117]),
list([18, 148]), list([128, 166, 201, 19, 266, 35, 106, 128]),
list([31, 166, 140, 184, 195]),
list([258, 195, 170, 261, 16, 158, 77, 31, 240]),
list([169, 195, 166, 261, 90, 16, 158]), list([264, 115, 150]),
list([265, 72, 120, 9, 132]), list([73, 64]),
list([181, 227, 54, 53, 129, 27, 70, 108]),
list([4, 177, 151, 75, 38]), list([238, 115, 151, 83, 93]),
list([260, 83, 197, 150, 195]), list([46, 170, 164, 86]),
list([146, 157, 251, 83, 196, 193, 197, 195]),
list([117, 195, 130, 217, 34, 238, 105, 126, 20, 122, 86]),
list([81, 197, 41, 250, 86]), list([83, 97])], dtype=object)},
array([[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[0],
[0],
[0],
[1],
[0],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[0],
[0],
[1],
[1],
[1],
[1],
[0],
[0],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1]]))
# Point-wise datasets over the processed train and validation packs, both
# configured with mode='point' and a batch size of 32.
trainset, validset = (
    mz.dataloader.Dataset(data_pack=pack, mode='point', batch_size=32)
    for pack in (train_processed, valid_processed)
)

# ArcI's default padding callback is shared by both loaders.
padding_callback = mz.models.ArcI.get_default_padding_callback()
trainloader = mz.dataloader.DataLoader(
    dataset=trainset, stage='train', callback=padding_callback
)
validloader = mz.dataloader.DataLoader(
    dataset=validset, stage='dev', callback=padding_callback
)
模型搭建
# Configure and build the ArcI network for the classification task.
# NOTE(review): this rebinds the module-level name `model`, which earlier held
# the FastHan segmenter; anything that resolves `model` at call time after this
# point gets the ArcI network instead.
model = mz.models.ArcI()
model.params['task'] = cls_task
model.params['embedding_output_dim'] = 100
# Embedding input size comes from the context of the fitted preprocessor.
model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim']
# Presumably fills the parameters not set above with library defaults -- the
# call must come before build().
model.guess_and_fill_missing_params()
model.build()
print(model)
ArcI(
(embedding): Embedding(267, 100, padding_idx=0)
(conv_left): Sequential(
(0): Sequential(
(0): ConstantPad1d(padding=(0, 2), value=0)
(1): Conv1d(100, 32, kernel_size=(3,), stride=(1,))
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
)
(conv_right): Sequential(
(0): Sequential(
(0): ConstantPad1d(padding=(0, 2), value=0)
(1): Conv1d(100, 32, kernel_size=(3,), stride=(1,))
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
)
(dropout): Dropout(p=0.0, inplace=False)
(mlp): Sequential(
(0): Sequential(
(0): Linear(in_features=1760, out_features=128, bias=True)
(1): ReLU()
)
(1): Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): ReLU()
)
(2): Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): ReLU()
)
(3): Sequential(
(0): Linear(in_features=128, out_features=64, bias=True)
(1): ReLU()
)
)
(out): Linear(in_features=64, out_features=2, bias=True)
)
# Adam over all ArcI parameters; no optimizer arguments are overridden.
optimizer = torch.optim.Adam(model.parameters())
# Train for 10 epochs using the train loader and the validation loader built
# earlier.
trainer = mz.trainers.Trainer(
model=model,
optimizer=optimizer,
trainloader=trainloader,
validloader=validloader,
epochs=10
)
trainer.run()
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-3 Loss-0.709]:
Validation: accuracy: 0.625
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-6 Loss-0.684]:
Validation: accuracy: 0.6667
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-9 Loss-0.603]:
Validation: accuracy: 0.6667
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-12 Loss-0.508]:
Validation: accuracy: 0.6667
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-15 Loss-0.379]:
Validation: accuracy: 0.6667
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-18 Loss-0.389]:
Validation: accuracy: 0.6667
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-21 Loss-0.438]:
Validation: accuracy: 0.75
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-24 Loss-0.327]:
Validation: accuracy: 0.8333
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-27 Loss-0.298]:
Validation: accuracy: 0.8333
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))
[Iter-30 Loss-0.255]:
Validation: accuracy: 0.75
Cost time: 4.359504461288452s