房地产问答机器人操作当前较为完整版本呱呱呱

利用matchzoo-py库实现房地产问答匹配问题

导入库包

# import matchzoo as mz
import pandas as pd
import numpy as np
import numpy as np 
# import tensorflow.keras as K
# from matchzoo.preprocessors import BasicPreprocessor
from sklearn.model_selection import train_test_split
import torch
import matchzoo as mz
# import datetime
C:\Users\Administrator\Anaconda3\lib\requests\__init__.py:80: RequestsDependencyWarning: urllib3 (1.25.11) or chardet (3.0.4) doesn't match a supported version!
  RequestsDependencyWarning)

数据处理

from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
# import tensorflow as tf
# from keras.layers import *
# from keras.models import Model
# import keras.backend as K
# from keras.optimizers import Adam
from random import choice
# from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import re, os
import codecs
# from keras.callbacks import Callback
#数据读取及处理
train_left = pd.read_csv('./train/train.query.tsv',sep='\t',header=None)
train_left.columns=['id','q1']
train_right = pd.read_csv('./train/train.reply.tsv',sep='\t',header=None)
train_right.columns=['id','id_sub','q2','label']
df_train = train_left.merge(train_right, how='left')
df_train['q2'] = df_train['q2'].fillna('好的')
test_left = pd.read_csv('./test/test.query.tsv',sep='\t',header=None, encoding='gbk')
test_left.columns = ['id','q1']
test_right =  pd.read_csv('./test/test.reply.tsv',sep='\t',header=None, encoding='gbk')
test_right.columns=['id','id_sub','q2']
df_test = test_left.merge(test_right, how='left')
sent1_=df_train.q1.values[:100]
sent2_=df_train.q2.values[:100]
label_=df_train.label.values[:100]
from tqdm import tqdm
from fastHan import FastHan
model=FastHan(model_type='base')
loading vocabulary file C:\Users\Administrator\.fastNLP\fasthan\fasthan_base\vocab.txt
Load pre-trained BERT parameters from file C:\Users\Administrator\.fastNLP\fasthan\fasthan_base\model.bin.
def change(tmp):
    tmp_list=[]
    co=0
    for i in tqdm(tmp):
        print(co)
#         print(model(i,'CWS'))
#         print(type(model(i,'CWS')))
        kkk=model(i,'CWS')
#         print(' '.join(kkk[0]))
        tmp_list.append(' '.join(kkk[0]))
        co=co+1
    return tmp_list
sent1_list=change(sent1_)
  2%|█▋                                                                                | 2/100 [00:00<00:06, 14.40it/s]

0
1
2
3


  6%|████▉                                                                             | 6/100 [00:00<00:05, 15.99it/s]

4
5
6
7
8

 11%|████████▉                                                                        | 11/100 [00:00<00:05, 15.91it/s]


9
10
11


 13%|██████████▌                                                                      | 13/100 [00:00<00:06, 13.94it/s]

12
13
14


 17%|█████████████▊                                                                   | 17/100 [00:01<00:05, 14.74it/s]

15
16
17
18


 21%|█████████████████                                                                | 21/100 [00:01<00:05, 15.69it/s]

19
20
21
22
23


 26%|█████████████████████                                                            | 26/100 [00:01<00:04, 16.57it/s]

24
25
26
27


 30%|████████████████████████▎                                                        | 30/100 [00:01<00:04, 16.21it/s]

28
29
30
31


 35%|████████████████████████████▎                                                    | 35/100 [00:02<00:03, 17.70it/s]

32
33
34
35


 39%|███████████████████████████████▌                                                 | 39/100 [00:02<00:03, 18.36it/s]

36
37
38
39


 43%|██████████████████████████████████▊                                              | 43/100 [00:02<00:03, 18.09it/s]

40
41
42
43


 45%|████████████████████████████████████▍                                            | 45/100 [00:02<00:03, 17.19it/s]

44
45
46
47


 50%|████████████████████████████████████████▌                                        | 50/100 [00:02<00:02, 16.78it/s]

48
49
50
51


 54%|███████████████████████████████████████████▋                                     | 54/100 [00:03<00:02, 17.57it/s]

52
53
54
55


 58%|██████████████████████████████████████████████▉                                  | 58/100 [00:03<00:02, 16.50it/s]

56
57
58
59


 62%|██████████████████████████████████████████████████▏                              | 62/100 [00:03<00:02, 16.90it/s]

60
61
62
63


 68%|███████████████████████████████████████████████████████                          | 68/100 [00:04<00:01, 18.25it/s]

64
65
66
67


 70%|████████████████████████████████████████████████████████▋                        | 70/100 [00:04<00:01, 18.28it/s]

68
69
70
71
72

 73%|███████████████████████████████████████████████████████████▏                     | 73/100 [00:04<00:01, 18.84it/s]


73
74

 75%|████████████████████████████████████████████████████████████▊                    | 75/100 [00:04<00:02, 12.50it/s]


75


 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [00:04<00:02, 10.99it/s]

76
77
78


 81%|█████████████████████████████████████████████████████████████████▌               | 81/100 [00:05<00:01, 13.74it/s]

79
80
81
82
83

 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [00:05<00:00, 16.09it/s]


84
85
86
87


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [00:05<00:00, 16.52it/s]

88
89
90
91


 94%|████████████████████████████████████████████████████████████████████████████▏    | 94/100 [00:05<00:00, 15.81it/s]

92
93
94
95


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.51it/s]

96
97
98
99
sent2_list=change(sent2_)
  2%|█▋                                                                                | 2/100 [00:00<00:06, 14.50it/s]

0
1
2
3


  6%|████▉                                                                             | 6/100 [00:00<00:05, 16.34it/s]

4
5
6
7


 10%|████████                                                                         | 10/100 [00:00<00:05, 16.40it/s]

8
9
10


 12%|█████████▋                                                                       | 12/100 [00:00<00:05, 15.63it/s]

11
12
13


 16%|████████████▉                                                                    | 16/100 [00:01<00:05, 14.34it/s]

14
15
16
17


 20%|████████████████▏                                                                | 20/100 [00:01<00:05, 13.89it/s]

18
19
20
21


 24%|███████████████████▍                                                             | 24/100 [00:01<00:04, 15.53it/s]

22
23
24
25
26


 29%|███████████████████████▍                                                         | 29/100 [00:01<00:04, 16.54it/s]

27
28
29
30


 33%|██████████████████████████▋                                                      | 33/100 [00:02<00:04, 16.57it/s]

31
32
33
34


 37%|█████████████████████████████▉                                                   | 37/100 [00:02<00:03, 17.75it/s]

35
36
37
38


 41%|█████████████████████████████████▏                                               | 41/100 [00:02<00:03, 17.85it/s]

39
40
41
42


 45%|████████████████████████████████████▍                                            | 45/100 [00:02<00:03, 16.02it/s]

43
44
45
46


 49%|███████████████████████████████████████▋                                         | 49/100 [00:03<00:03, 16.45it/s]

47
48
49
50


 53%|██████████████████████████████████████████▉                                      | 53/100 [00:03<00:02, 16.53it/s]

51
52
53
54


 57%|██████████████████████████████████████████████▏                                  | 57/100 [00:03<00:02, 16.71it/s]

55
56
57
58


 59%|███████████████████████████████████████████████▊                                 | 59/100 [00:03<00:02, 17.22it/s]

59
60


 61%|█████████████████████████████████████████████████▍                               | 61/100 [00:03<00:02, 13.05it/s]

61
62

 63%|███████████████████████████████████████████████████                              | 63/100 [00:04<00:03,  9.87it/s]


63
64


 67%|██████████████████████████████████████████████████████▎                          | 67/100 [00:04<00:02, 12.57it/s]

65
66
67
68


 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [00:04<00:01, 15.77it/s]

69
70
71
72
73


 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [00:04<00:01, 17.58it/s]

74
75
76
77
78

 83%|███████████████████████████████████████████████████████████████████▏             | 83/100 [00:05<00:00, 19.33it/s]


79
80
81
82
83


 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [00:05<00:00, 18.83it/s]

84
85
86
87


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [00:05<00:00, 18.70it/s]

88
89
90
91


 95%|████████████████████████████████████████████████████████████████████████████▉    | 95/100 [00:05<00:00, 19.67it/s]

92
93
94
95
96
97

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.38it/s]


98
99
all_data=pd.DataFrame()
all_data['text_left']=sent1_list
all_data['text_right']=sent2_list
all_data['id_left']=range(len(sent1_list))
all_data['id_right']=range(len(sent2_list))
all_data['label']=label_
_sent1=df_test.q1.values[:100]
_sent2=df_test.q2.values[:100]
# _label=label[2501:]
_sent1list=change(_sent1)
_sent2list=change(_sent2)
  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

0

  2%|█▋                                                                                | 2/100 [00:00<00:06, 14.09it/s]


1
2


  6%|████▉                                                                             | 6/100 [00:00<00:06, 15.32it/s]

3
4
5
6


 11%|████████▉                                                                        | 11/100 [00:00<00:05, 17.48it/s]

7
8
9
10
11


 15%|████████████▏                                                                    | 15/100 [00:00<00:04, 17.45it/s]

12
13
14
15


 19%|███████████████▍                                                                 | 19/100 [00:01<00:04, 17.37it/s]

16
17
18
19


 23%|██████████████████▋                                                              | 23/100 [00:01<00:04, 16.29it/s]

20
21
22
23


 27%|█████████████████████▊                                                           | 27/100 [00:01<00:04, 16.96it/s]

24
25
26
27


 31%|█████████████████████████                                                        | 31/100 [00:01<00:03, 17.35it/s]

28
29
30
31


 33%|██████████████████████████▋                                                      | 33/100 [00:01<00:04, 16.34it/s]

32
33
34


 37%|█████████████████████████████▉                                                   | 37/100 [00:02<00:03, 15.93it/s]

35
36
37
38


 41%|█████████████████████████████████▏                                               | 41/100 [00:02<00:03, 16.81it/s]

39
40
41
42


 45%|████████████████████████████████████▍                                            | 45/100 [00:02<00:03, 17.04it/s]

43
44
45
46


 49%|███████████████████████████████████████▋                                         | 49/100 [00:02<00:02, 17.11it/s]

47
48
49
50


 51%|█████████████████████████████████████████▎                                       | 51/100 [00:03<00:02, 17.79it/s]

51
52


 53%|██████████████████████████████████████████▉                                      | 53/100 [00:03<00:03, 12.23it/s]

53
54


 57%|██████████████████████████████████████████████▏                                  | 57/100 [00:03<00:03, 11.09it/s]

55
56
57


 61%|█████████████████████████████████████████████████▍                               | 61/100 [00:03<00:03, 12.72it/s]

58
59
60
61


 65%|████████████████████████████████████████████████████▋                            | 65/100 [00:04<00:02, 13.78it/s]

62
63
64
65


 69%|███████████████████████████████████████████████████████▉                         | 69/100 [00:04<00:02, 14.98it/s]

66
67
68
69


 73%|███████████████████████████████████████████████████████████▏                     | 73/100 [00:04<00:01, 16.94it/s]

70
71
72
73
74

 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [00:04<00:01, 17.28it/s]


75
76
77
78


 81%|█████████████████████████████████████████████████████████████████▌               | 81/100 [00:05<00:01, 17.84it/s]

79
80
81
82
83

 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [00:05<00:00, 18.67it/s]


84
85
86
87


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [00:05<00:00, 17.55it/s]

88
89
90
91


 94%|████████████████████████████████████████████████████████████████████████████▏    | 94/100 [00:05<00:00, 13.44it/s]

92
93
94


 98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [00:06<00:00, 14.87it/s]

95
96
97
98


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.73it/s]
  2%|█▋                                                                                | 2/100 [00:00<00:05, 19.06it/s]

99
0
1
2


  6%|████▉                                                                             | 6/100 [00:00<00:05, 17.73it/s]

3
4
5
6


 10%|████████                                                                         | 10/100 [00:00<00:04, 18.02it/s]

7
8
9
10


 14%|███████████▎                                                                     | 14/100 [00:00<00:04, 17.49it/s]

11
12
13
14


 19%|███████████████▍                                                                 | 19/100 [00:01<00:04, 17.72it/s]

15
16
17
18
19


 22%|█████████████████▊                                                               | 22/100 [00:01<00:04, 17.55it/s]

20
21
22
23


 26%|█████████████████████                                                            | 26/100 [00:01<00:04, 17.85it/s]

24
25
26
27


 30%|████████████████████████▎                                                        | 30/100 [00:01<00:03, 18.53it/s]

28
29
30
31
32


 36%|█████████████████████████████▏                                                   | 36/100 [00:01<00:03, 20.68it/s]

33
34
35
36
37


 42%|██████████████████████████████████                                               | 42/100 [00:02<00:02, 21.40it/s]

38
39
40
41
42
43


 45%|████████████████████████████████████▍                                            | 45/100 [00:02<00:04, 11.67it/s]

44
45


 47%|██████████████████████████████████████                                           | 47/100 [00:02<00:04, 11.74it/s]

46
47
48


 51%|█████████████████████████████████████████▎                                       | 51/100 [00:03<00:03, 12.87it/s]

49
50
51
52


 55%|████████████████████████████████████████████▌                                    | 55/100 [00:03<00:03, 14.34it/s]

53
54
55
56


 59%|███████████████████████████████████████████████▊                                 | 59/100 [00:03<00:02, 15.54it/s]

57
58
59
60


 63%|███████████████████████████████████████████████████                              | 63/100 [00:03<00:02, 14.36it/s]

61
62
63
64

 65%|████████████████████████████████████████████████████▋                            | 65/100 [00:04<00:02, 13.50it/s]


65
66


 69%|███████████████████████████████████████████████████████▉                         | 69/100 [00:04<00:02, 14.14it/s]

67
68
69
70


 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [00:04<00:01, 15.60it/s]

71
72
73
74


 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [00:04<00:01, 15.36it/s]

75
76
77


 79%|███████████████████████████████████████████████████████████████▉                 | 79/100 [00:05<00:01, 15.03it/s]

78
79
80


 83%|███████████████████████████████████████████████████████████████████▏             | 83/100 [00:05<00:01, 14.02it/s]

81
82
83


 88%|███████████████████████████████████████████████████████████████████████▎         | 88/100 [00:05<00:00, 15.69it/s]

84
85
86
87
88


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [00:05<00:00, 14.71it/s]

89
90
91
92


 96%|█████████████████████████████████████████████████████████████████████████████▊   | 96/100 [00:06<00:00, 15.95it/s]

93
94
95
96


 98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [00:06<00:00, 12.88it/s]

97
98
99


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.50it/s]
# train_split = int(len(df_train) * 0.7)  # 训练集占训练数据的0.7
# dev_split = int(len(df_train) * 0.9)  # 测试集占训练数据的0.2
# # test_split = int(len(data_df) * 0.6)  # 验证集占训练数据的0.1
train_split = int(len(sent1_) * 0.7)  # 训练集占训练数据的0.7
dev_split = int(len(sent1_) * 0.9)  # 测试集占训练数据的0.2
# test_split = int(len(data_df) * 0.6)  # 验证集占训练数据的0.1
tmp_data=pd.DataFrame()
tmp_data['text_left']=_sent1list
tmp_data['text_right']=_sent2list
tmp_data['id_left']=range(len(_sent1list))
tmp_data['id_right']=range(len(_sent2list))
# test_data['label']=_label
cls_task=mz.tasks.Classification(num_classes=2)
cls_task.metrics=[mz.metrics.Accuracy]
def load_data(df_data):
# 	df_data = pd.read_csv(data_path, sep='\t', header=None)
# 	df_data = pd.DataFrame(df_data.values, columns=['id_left', 'text_left', 'id_right', 'text_right', 'label'])
	df_data = mz.pack(df_data,task=cls_task)
	return df_data

train_data = load_data(all_data)
test_data=load_data(tmp_data)
train_split
70
train = train_data[:train_split]
dev = train_data[train_split:]
# train_pack_processed = preprocessor.fit_transform(train) 
# # 其实就是做了一个字符转id操作,所以对于中文文本,不需要分词
# dev_pack_processed = preprocessor.transform(dev) 
# dev 
train_split
70
preprocessor = mz.models.ArcI.get_default_preprocessor()
train_processed = preprocessor.fit_transform(train)
valid_processed = preprocessor.transform(dev)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1459.16it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1522.59it/s]
Processing text_right with append: 100%|████████████████████████████████████████████| 70/70 [00:00<00:00, 35027.59it/s]
Building FrequencyFilter from a datapack.: 100%|████████████████████████████████████| 70/70 [00:00<00:00, 17510.66it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 35031.77it/s]
Processing text_left with extend: 100%|█████████████████████████████████████████████| 70/70 [00:00<00:00, 23349.87it/s]
Processing text_right with extend: 100%|████████████████████████████████████████████| 70/70 [00:00<00:00, 35077.81it/s]
Building Vocabulary from a datapack.: 100%|██████████████████████████████████████| 812/812 [00:00<00:00, 812252.53it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 2122.50it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 70/70 [00:00<00:00, 1591.75it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 17503.36it/s]
Processing text_left with transform: 100%|██████████████████████████████████████████| 70/70 [00:00<00:00, 11674.94it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 70/70 [00:00<00:00, 11661.95it/s]
Processing length_left with len: 100%|██████████████████████████████████████████████| 70/70 [00:00<00:00, 35015.06it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 70/70 [00:00<00:00, 35035.95it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 30/30 [00:00<00:00, 1579.85it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|█| 30/30 [00:00<00:00, 2728.30it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 30/30 [00:00<00:00, 10002.31it/s]
Processing text_left with transform: 100%|███████████████████████████████████████████| 30/30 [00:00<00:00, 2501.87it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 30/30 [00:00<00:00, 15036.94it/s]
Processing length_left with len: 100%|███████████████████████████████████████████████| 30/30 [00:00<00:00, 6001.29it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 30/30 [00:00<00:00, 29987.87it/s]
valid_processed.unpack()
({'id_left': array([70, 71, 72, 73, 74, 77, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 90,
         91, 92, 93, 95, 97, 98, 99]),
  'id_right': array([70, 71, 72, 73, 74, 77, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 90,
         91, 92, 93, 95, 97, 98, 99]),
  'length_left': array([ 4,  4,  7,  7,  7,  6,  5,  5,  5,  5,  5,  5,  5,  8,  8, 10, 10,
         10, 10, 10,  2,  4,  4,  4]),
  'length_right': array([3, 2, 2, 3, 1, 2, 1, 1, 7, 2, 5, 1, 4, 5, 1, 4, 6, 3, 3, 1, 2, 7,
         1, 3]),
  'text_left': array([list([144, 197, 15, 151]), list([144, 197, 15, 151]),
         list([238, 30, 151, 49, 220, 228, 86]),
         list([238, 30, 151, 49, 220, 228, 86]),
         list([238, 30, 151, 49, 220, 228, 86]),
         list([146, 144, 198, 1, 195, 151]), list([1, 236, 220, 228, 86]),
         list([1, 236, 220, 228, 86]), list([1, 236, 220, 228, 86]),
         list([238, 166, 14, 172, 86]), list([238, 166, 14, 172, 86]),
         list([238, 166, 14, 172, 86]), list([238, 166, 14, 172, 86]),
         list([142, 117, 146, 144, 197, 15, 238, 151]),
         list([142, 117, 146, 144, 197, 15, 238, 151]),
         list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
         list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
         list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
         list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]),
         list([142, 117, 1, 1, 238, 30, 149, 170, 1, 1]), list([1, 86]),
         list([61, 197, 151, 99]), list([61, 197, 151, 99]),
         list([61, 197, 151, 99])], dtype=object),
  'text_right': array([list([142, 45, 197]), list([117, 195]), list([142, 117]),
         list([49, 83, 228]), list([195]), list([83, 195]), list([83]),
         list([195]), list([147, 170, 251, 195, 56, 170, 164]),
         list([78, 197]), list([146, 22, 30, 75, 38]), list([110]),
         list([195, 56, 188, 117]), list([146, 215, 142, 80, 30]),
         list([38]), list([256, 12, 137, 142]),
         list([8, 177, 70, 195, 7, 177]), list([244, 195, 122]),
         list([202, 42, 136]), list([195]), list([217, 86]),
         list([238, 115, 151, 166, 250, 195, 137]), list([142]),
         list([221, 236, 166])], dtype=object)},
 array([[1],
        [0],
        [0],
        [1],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0]]))
train_processed.unpack()
({'id_left': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
         52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
         69]),
  'id_right': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
         52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
         69]),
  'length_left': array([ 6,  6,  6,  2,  2,  2,  2,  2, 10, 10, 10, 10, 10,  3,  3,  3,  9,
          9,  9,  4,  4,  4,  4,  4,  9,  9,  9,  7,  7,  7,  5,  5,  3,  3,
          3,  3,  3,  3,  3,  3,  3,  3,  9,  9,  2,  2,  2,  8,  8,  8,  4,
          4,  4,  9,  9,  9,  9,  2,  2,  2,  5,  5,  5,  3,  3,  3,  3,  3,
          4]),
  'length_right': array([11,  2,  4,  6,  2,  9,  7,  6,  8,  4,  9,  2, 10,  7,  5, 10,  1,
         18, 27,  3,  5,  4,  8, 12,  4,  2, 16,  8,  1,  5,  8,  5,  2,  7,
          3,  6,  7,  9,  1,  4, 10, 10,  7,  6,  2,  2,  5,  9, 12, 13,  9,
          2,  2,  8,  5,  9,  7,  3,  5,  2,  8,  5,  5,  5,  4,  8, 11,  5,
          2]),
  'text_left': array([list([249, 14, 126, 166, 68, 87]),
         list([249, 14, 126, 166, 68, 87]),
         list([249, 14, 126, 166, 68, 87]), list([180, 86]),
         list([180, 86]), list([180, 86]), list([180, 86]), list([180, 86]),
         list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
         list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
         list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
         list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
         list([57, 195, 58, 111, 166, 108, 90, 114, 166, 108]),
         list([263, 131, 99]), list([263, 131, 99]), list([263, 131, 99]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
         list([239, 206, 108, 252]), list([239, 206, 108, 252]),
         list([239, 206, 108, 252]), list([239, 206, 108, 252]),
         list([239, 206, 108, 252]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 33]),
         list([170, 151, 83, 133, 146, 198, 86]),
         list([170, 151, 83, 133, 146, 198, 86]),
         list([170, 151, 83, 133, 146, 198, 86]),
         list([62, 27, 117, 166, 86]), list([62, 27, 117, 166, 86]),
         list([67, 172, 195]), list([67, 172, 195]), list([67, 172, 195]),
         list([67, 172, 195]), list([67, 172, 195]), list([45, 143, 88]),
         list([45, 143, 88]), list([45, 143, 88]), list([45, 143, 88]),
         list([45, 143, 88]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 86]),
         list([238, 115, 154, 49, 236, 170, 50, 208, 86]), list([234, 86]),
         list([234, 86]), list([234, 86]),
         list([165, 236, 138, 78, 197, 15, 238, 115]),
         list([165, 236, 138, 78, 197, 15, 238, 115]),
         list([165, 236, 138, 78, 197, 15, 238, 115]),
         list([197, 150, 161, 86]), list([197, 150, 161, 86]),
         list([197, 150, 161, 86]),
         list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
         list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
         list([238, 30, 127, 247, 95, 30, 128, 90, 31]),
         list([238, 30, 127, 247, 95, 30, 128, 90, 31]), list([265, 108]),
         list([265, 108]), list([265, 108]), list([49, 168, 55, 108, 75]),
         list([49, 168, 55, 108, 75]), list([49, 168, 55, 108, 75]),
         list([248, 141, 88]), list([248, 141, 88]), list([248, 141, 88]),
         list([248, 141, 88]), list([248, 141, 88]),
         list([144, 197, 15, 151])], dtype=object),
  'text_right': array([list([173, 249, 210, 128, 254, 175, 173, 253, 159, 121, 119]),
         list([166, 195]), list([238, 166, 7, 177]),
         list([100, 65, 230, 229, 195, 129]), list([166, 88]),
         list([238, 115, 14, 177, 133, 259, 195, 142, 198]),
         list([152, 246, 166, 42, 136, 107, 195]),
         list([117, 195, 142, 61, 197, 23]),
         list([142, 166, 264, 115, 237, 40, 115, 88]),
         list([155, 231, 24, 10]),
         list([73, 199, 207, 90, 147, 195, 171, 90, 153]), list([117, 195]),
         list([256, 112, 12, 194, 166, 13, 85, 139, 260, 218]),
         list([178, 255, 230, 83, 221, 236, 195]),
         list([170, 14, 188, 263, 131]),
         list([81, 170, 245, 14, 115, 180, 195, 27, 263, 131]), list([170]),
         list([182, 186, 183, 195, 160, 104, 5, 133, 225, 198, 78, 86, 50, 189, 137, 109, 43, 191]),
         list([124, 250, 219, 74, 103, 7, 8, 84, 214, 163, 202, 235, 42, 136, 96, 4, 177, 76, 244, 59, 29, 200, 75, 197, 150, 170, 251]),
         list([142, 198, 239]), list([3, 177, 211, 225, 195]),
         list([142, 192, 161, 86]),
         list([83, 61, 214, 22, 133, 142, 197, 15]),
         list([142, 205, 15, 215, 142, 80, 257, 142, 188, 242, 130, 224]),
         list([49, 89, 82, 228]), list([142, 117]),
         list([142, 117, 142, 36, 91, 195, 154, 149, 137, 51, 142, 170, 164, 78, 198, 86]),
         list([123, 222, 233, 170, 115, 11, 135, 195]), list([209]),
         list([83, 133, 142, 197, 15]),
         list([62, 262, 117, 79, 223, 246, 166, 83]),
         list([56, 190, 101, 243, 198]), list([166, 195]),
         list([238, 30, 166, 40, 172, 195, 216]), list([236, 170, 195]),
         list([40, 172, 246, 69, 26, 47]),
         list([14, 172, 195, 170, 63, 195, 38]),
         list([147, 25, 179, 167, 192, 37, 220, 197, 150]), list([94]),
         list([110, 166, 241, 195]),
         list([60, 37, 52, 176, 145, 49, 204, 170, 17, 185]),
         list([28, 238, 203, 156, 151, 185, 208, 246, 27, 109]),
         list([238, 115, 151, 49, 236, 170, 208]),
         list([66, 226, 197, 142, 48, 162]), list([125, 195]),
         list([21, 195]), list([142, 144, 39, 15, 45]),
         list([165, 78, 198, 92, 2, 116, 118, 37, 198]),
         list([146, 134, 56, 174, 23, 98, 69, 71, 146, 197, 23, 212]),
         list([146, 46, 232, 44, 197, 160, 38, 213, 38, 43, 113, 187, 38]),
         list([238, 115, 151, 102, 6, 177, 166, 32, 177]), list([56, 117]),
         list([18, 148]), list([128, 166, 201, 19, 266, 35, 106, 128]),
         list([31, 166, 140, 184, 195]),
         list([258, 195, 170, 261, 16, 158, 77, 31, 240]),
         list([169, 195, 166, 261, 90, 16, 158]), list([264, 115, 150]),
         list([265, 72, 120, 9, 132]), list([73, 64]),
         list([181, 227, 54, 53, 129, 27, 70, 108]),
         list([4, 177, 151, 75, 38]), list([238, 115, 151, 83, 93]),
         list([260, 83, 197, 150, 195]), list([46, 170, 164, 86]),
         list([146, 157, 251, 83, 196, 193, 197, 195]),
         list([117, 195, 130, 217, 34, 238, 105, 126, 20, 122, 86]),
         list([81, 197, 41, 250, 86]), list([83, 97])], dtype=object)},
 array([[1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1]]))
trainset = mz.dataloader.Dataset(
    data_pack=train_processed,
    mode='point',
#     num_dup=1,
#     num_neg=4,
    batch_size=32
)
validset = mz.dataloader.Dataset(
    data_pack=valid_processed,
    mode='point',
    batch_size=32
)
padding_callback = mz.models.ArcI.get_default_padding_callback()

trainloader = mz.dataloader.DataLoader(
    dataset=trainset,
    stage='train',
    callback=padding_callback
)
validloader = mz.dataloader.DataLoader(
    dataset=validset,
    stage='dev',
    callback=padding_callback
)

模型搭建

model = mz.models.ArcI()
model.params['task'] = cls_task
model.params['embedding_output_dim'] = 100
model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim']
model.guess_and_fill_missing_params()
model.build()
print(model)
ArcI(
  (embedding): Embedding(267, 100, padding_idx=0)
  (conv_left): Sequential(
    (0): Sequential(
      (0): ConstantPad1d(padding=(0, 2), value=0)
      (1): Conv1d(100, 32, kernel_size=(3,), stride=(1,))
      (2): ReLU()
      (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (conv_right): Sequential(
    (0): Sequential(
      (0): ConstantPad1d(padding=(0, 2), value=0)
      (1): Conv1d(100, 32, kernel_size=(3,), stride=(1,))
      (2): ReLU()
      (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (mlp): Sequential(
    (0): Sequential(
      (0): Linear(in_features=1760, out_features=128, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
    )
    (3): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (out): Linear(in_features=64, out_features=2, bias=True)
)
optimizer = torch.optim.Adam(model.parameters())

trainer = mz.trainers.Trainer(
    model=model,
    optimizer=optimizer,
    trainloader=trainloader,
    validloader=validloader,
    epochs=10
)

trainer.run()
HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-3 Loss-0.709]:
  Validation: accuracy: 0.625




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-6 Loss-0.684]:
  Validation: accuracy: 0.6667




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-9 Loss-0.603]:
  Validation: accuracy: 0.6667




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-12 Loss-0.508]:
  Validation: accuracy: 0.6667




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-15 Loss-0.379]:
  Validation: accuracy: 0.6667




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-18 Loss-0.389]:
  Validation: accuracy: 0.6667




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-21 Loss-0.438]:
  Validation: accuracy: 0.75




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-24 Loss-0.327]:
  Validation: accuracy: 0.8333




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-27 Loss-0.298]:
  Validation: accuracy: 0.8333




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


[Iter-30 Loss-0.255]:
  Validation: accuracy: 0.75

Cost time: 4.359504461288452s
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南楚巫妖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值