Siamese+LSTM
Network structure module
- Siamese network
First, what makes this a Siamese network: the word embedding and the encoder (the LSTM) use the same parameters for both inputs. In my code below, both sentences are fed through the same LSTM during encoding, and that is exactly what Siamese+LSTM means.
- A question
One thing puzzled me, though. In some articles I have seen an LSTM used for sentence similarity that is not Siamese+LSTM: the article compares two approaches, Siamese+LSTM versus plain LSTM. I was not sure what the standalone LSTM meant; after thinking about it, my guess is that the two sentences are encoded with two separate LSTMs that do not share weights. A minimal sketch of the two variants follows below.
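To make the difference concrete, here is a minimal sketch of the two variants as I understand them (the module names are made up purely for illustration; only the weight sharing versus two independent encoders matters):

import torch
import torch.nn as nn

class SiameseEncoder(nn.Module):
    # Siamese: one LSTM, reused for both sentences (shared weights)
    def __init__(self, emb_dim=300, hidden=128):
        super().__init__()
        self.lstm = nn.LSTM(emb_dim, hidden)

    def forward(self, left, right):            # both: [len, batch, emb_dim]
        out_l, _ = self.lstm(left)
        out_r, _ = self.lstm(right)             # same parameters as the call above
        return out_l[-1], out_r[-1]

class TwoTowerEncoder(nn.Module):
    # Non-Siamese: two LSTMs, each sentence gets its own parameters
    def __init__(self, emb_dim=300, hidden=128):
        super().__init__()
        self.lstm_left = nn.LSTM(emb_dim, hidden)
        self.lstm_right = nn.LSTM(emb_dim, hidden)

    def forward(self, left, right):
        out_l, _ = self.lstm_left(left)
        out_r, _ = self.lstm_right(right)
        return out_l[-1], out_r[-1]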
- Encoding
After encoding, i.e. at the LSTM output stage: the LSTM output has shape [len, batch, hidden], because I did not set batch_first when constructing the LSTM. From this output I take the last time step, which plays the role of final_ht. A quick shape check is shown below.
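A quick shape check with toy sizes, just to confirm the [len, batch, hidden] layout and that outputs[-1] is the last time step:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=300, hidden_size=128)    # batch_first defaults to False
x = torch.randn(20, 4, 300)                        # [len=20, batch=4, emb_dim=300]
outputs, (h_n, c_n) = lstm(x)
print(outputs.shape)       # torch.Size([20, 4, 128]) -> [len, batch, hidden]
print(outputs[-1].shape)   # torch.Size([4, 128])     -> last time step, i.e. final_ht
print(torch.equal(outputs[-1], h_n[-1]))   # True for a single-layer, unidirectional LSTM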
- Distance
Both sentence vectors have shape [batch, hidden]. I take the element-wise absolute difference between them, torch.abs(left - right), as the input to the fully connected network. Note that torch.abs is element-wise and takes no dim argument; a dim (normally dim=-1) only comes into play if you additionally sum the differences into a single Manhattan distance, as sketched below.
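For reference, the element-wise feature I use versus the summed Manhattan distance (the latter is the variant where dim=-1 matters):

import torch

left = torch.randn(4, 128)      # [batch, hidden]
right = torch.randn(4, 128)

feature = torch.abs(left - right)                          # [batch, hidden], input to the FC layers
manhattan = torch.sum(torch.abs(left - right), dim=-1)     # [batch], one distance per pair
print(feature.shape, manhattan.shape)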
- The sigmoid() function
As for why I use sigmoid at the end: I squash the output into [0, 1] with sigmoid and train with nn.BCELoss(). You could instead set this up as a two-class classification; I tried that, and it did not work as well as the single output in [0, 1], although neither result is great...
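A minimal example of this loss setup (nn.BCELoss expects probabilities in [0, 1] and float targets; nn.BCEWithLogitsLoss would fold the sigmoid into the loss, which is numerically more stable):

import torch
import torch.nn as nn

logits = torch.randn(8, 1)                     # raw scores from the last FC layer
probs = torch.sigmoid(logits)                  # squashed into [0, 1]
labels = torch.randint(0, 2, (8, 1)).float()   # 0/1 targets as floats

loss = nn.BCELoss()(probs, labels)
loss2 = nn.BCEWithLogitsLoss()(logits, labels)   # equivalent, applied to the raw logits
print(loss.item(), loss2.item())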
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
file1 = r"./data/vocab.pkl"            # vocabulary (word -> index)
file2 = r"./data/vocabs_matrix.pkl"    # pre-trained embedding matrix aligned with the vocabulary
vocabs = pickle.load(open(file1, 'rb'))
Embedding_matrix = pickle.load(open(file2, 'rb'))
Vocab_size = len(vocabs)
class LSTM1(nn.Module):
    def __init__(self):
        super(LSTM1, self).__init__()
        self.Vocab_size = Vocab_size
        self.batch_size = 500
        self.input_size = 300
        self.n_hidden1 = 128
        self.Embedding_dim = 300
        self.n_class = 2
        self.seq_len = 20
        self.dropout = nn.Dropout(0.2)
        self.Embedding_matrix = Embedding_matrix
        # embedding layer initialised from the pre-trained matrix
        self.word_embeds = nn.Embedding(self.Vocab_size + 1, self.Embedding_dim)
        pretrained_weight = np.array(self.Embedding_matrix)
        self.word_embeds.weight.data.copy_(torch.from_numpy(pretrained_weight))
        # one LSTM shared by both sentences -- this is the "Siamese" part
        self.Lstm1 = nn.LSTM(self.Embedding_dim, hidden_size=self.n_hidden1, bidirectional=False)
        #self.fc = nn.Linear(self.n_hidden1*2, self.n_class, bias=False)
        self.fc1 = nn.Linear(self.n_hidden1, 32, bias=False)
        self.b1 = nn.Parameter(torch.rand([32]))
        self.fc2 = nn.Linear(32, 1, bias=False)
        self.b2 = nn.Parameter(torch.rand([1]))

    def forward(self, train_left, train_right):
        device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
        train_left = self.word_embeds(train_left).to(device)    # [batch, len, emb]
        train_right = self.word_embeds(train_right).to(device)
        train_left = train_left.transpose(0, 1)                  # [len, batch, emb]
        train_right = train_right.transpose(0, 1)
        # derive the batch size from the input so a smaller final batch also works
        batch_size = train_left.size(1)
        hidden_state1 = torch.rand(1, batch_size, self.n_hidden1).to(device)
        cell_state1 = torch.rand(1, batch_size, self.n_hidden1).to(device)
        # the same Lstm1 encodes both sentences (shared weights)
        outputs1_L, (final_state1_L, _) = self.Lstm1(train_left, (hidden_state1, cell_state1))
        outputs1_L = self.dropout(outputs1_L)
        outputs1_R, (final_state1_R, _) = self.Lstm1(train_right, (hidden_state1, cell_state1))
        outputs1_R = self.dropout(outputs1_R)
        outputs1 = outputs1_L[-1]                                # last time step: [batch, hidden]
        outputs2 = outputs1_R[-1]
        output = torch.abs(outputs1 - outputs2)                  # element-wise |h_left - h_right|
        output = self.fc1(output) + self.b1
        output = self.dropout(output)
        output = self.fc2(output) + self.b2
        output = torch.sigmoid(output)                           # probability in [0, 1]
        return output
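A quick sanity check of the model's input and output shapes (dummy token indices, assuming the vocabulary and embedding matrix above loaded successfully and sentences are padded to seq_len=20):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTM1().to(device)
left = torch.randint(0, Vocab_size, (8, 20)).to(device)    # [batch, len] of token indices
right = torch.randint(0, Vocab_size, (8, 20)).to(device)
scores = model(left, right)
print(scores.shape)   # torch.Size([8, 1]): one similarity probability in [0, 1] per pair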
Training module
def train(model, device, train_dataloader, optimizer, epoch):
    model.train()
    train_loss = 0
    num_correct = 0
    for batch_idx, (train_left, train_right, labels) in enumerate(train_dataloader):
        train_left = train_left.to(device)
        train_right = train_right.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(train_left, train_right)    # [batch, 1] similarity probabilities
        output = output.view_as(labels)
        loss = loss_fn(output, labels)             # loss_fn = nn.BCELoss(), defined outside
        loss.backward()
        optimizer.step()
        train_loss += float(loss.item())
        true = labels.data.cpu()
        predict = torch.round(output).cpu()        # threshold the probabilities at 0.5
        num_correct += torch.eq(predict, true).sum().float().item()
    total_len = len(train_dataloader.dataset)
    train_acc = num_correct / total_len
    train_loss = train_loss / len(train_dataloader)
    print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t Acc: {:.6f}'.format(
        epoch,
        batch_idx * len(train_left),
        total_len,
        100. * batch_idx / len(train_dataloader),
        train_loss,
        train_acc))
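For completeness, here is a sketch of how the pieces could be wired together. loss_fn and the optimizer are not shown in the snippets above; nn.BCELoss matches the sigmoid output as noted earlier, but the optimizer type and the dummy tensors are my assumptions (the log below only shows the learning rate 0.01 and 10 epochs):

from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTM1().to(device)
loss_fn = nn.BCELoss()                                  # targets must be floats in {0, 1}
optimizer = optim.Adam(model.parameters(), lr=0.01)     # optimizer choice is an assumption

# dummy data standing in for the preprocessed, padded Quora question pairs
left_ids = torch.randint(0, Vocab_size, (2000, 20))
right_ids = torch.randint(0, Vocab_size, (2000, 20))
labels = torch.randint(0, 2, (2000,)).float()

train_dataset = TensorDataset(left_ids, right_ids, labels)
train_dataloader = DataLoader(train_dataset, batch_size=500, shuffle=True)

for epoch in range(1, 11):
    train(model, device, train_dataloader, optimizer, epoch)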
Results (the dataset is Quora Question Pairs)
D:\Anaconda3\envs\python36\python.exe D:/TextMatching/NewProject2/train.py
Learning rate: 0.01
Epochs: 10
Training on 363861 samples...
Train epoch: 1 [363000/363861 (100%)] Loss: 0.450285 Acc: 0.787457
Make prediction for 40429 samples...
F1 improved at epoch: 1 ; best_F1:0.75543 ; best_Accuracy:0.79368 ; best_Precision:0.67190 ; best_Recall:0.86267
Training on 363861 samples...
Train epoch: 2 [363000/363861 (100%)] Loss: 0.368117 Acc: 0.833332
Make prediction for 40429 samples...
F1:0.75543 Accuracy:0.79368 Precision:0.67190 Recall:0.86267 No improvement since epoch: 1
Training on 363861 samples...
Train epoch: 3 [363000/363861 (100%)] Loss: 0.336102 Acc: 0.848426
Make prediction for 40429 samples...
F1 improved at epoch: 3 ; best_F1:0.76229 ; best_Accuracy:0.79602 ; best_Precision:0.66912 ; best_Recall:0.88559
Training on 363861 samples...
Train epoch: 4 [363000/363861 (100%)] Loss: 0.320470 Acc: 0.855816
Make prediction for 40429 samples...
F1 improved at epoch: 4 ; best_F1:0.77047 ; best_Accuracy:0.81200 ; best_Precision:0.70121 ; best_Recall:0.85491
Training on 363861 samples...
Train epoch: 5 [363000/363861 (100%)] Loss: 0.308956 Acc: 0.862206
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Training on 363861 samples...
Train epoch: 6 [363000/363861 (100%)] Loss: 0.302345 Acc: 0.864965
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Training on 363861 samples...
Train epoch: 7 [363000/363861 (100%)] Loss: 0.296583 Acc: 0.867658
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Training on 363861 samples...
Train epoch: 8 [363000/363861 (100%)] Loss: 0.293175 Acc: 0.868862
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Training on 363861 samples...
Train epoch: 9 [363000/363861 (100%)] Loss: 0.288697 Acc: 0.871360
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Training on 363861 samples...
Train epoch: 10 [363000/363861 (100%)] Loss: 0.287799 Acc: 0.871404
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4
Process finished with exit code 0
Finally, if you are also working on text similarity, feel free to get in touch and exchange ideas; I am still a beginner at this myself.