The ESIM model code is as follows:
# -*- coding: utf-8 -*-
# @Time : 2020/2/25 11:19
# @Author : liusen
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_helper import TEXT
from data_helper import train_iter, val_iter

# Pre-trained word vectors from the torchtext vocabulary built in data_helper
weight_matrix = TEXT.vocab.vectors
class ESIM(nn.Module):
    def __init__(self):
        super(ESIM, self).__init__()
        self.dropout = 0.5
        self.hidden_size = 128
        self.embeds_dim = 300
        self.embeds = nn.Embedding(len(TEXT.vocab), self.embeds_dim)
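        # Note: weight_matrix (the pre-trained vectors loaded above) is not copied into
        # this embedding table anywhere in the snippet shown here; a typical way to do so
        # would be self.embeds.weight.data.copy_(weight_matrix) -- this is an assumption
        # about the intended setup, not something shown in this code.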
        # BatchNorm over the embedding dimension
        self.bn_embeds = nn.BatchNorm1d(self.embeds_dim)
        self.lstm1 = nn.LSTM(self.embeds_dim, self.hidden_size, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(self.hidden_size * 8, self.hidden_size, batch_first=True, bidirectional=True)
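        # lstm2 takes hidden_size * 8 features per step: in the standard ESIM composition
        # step, the bidirectional lstm1 output o and its soft-aligned counterpart align
        # (2 * hidden_size each) are concatenated as [o, align, o - align, o * align],
        # which gives 8 * hidden_size.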
        self.fc = nn.Sequential(
            nn.BatchNorm1d(self.hidden_size * 8),
            nn.Linear(self.hidden_size * 8, 2),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(2),
            nn.Dropout(self.dropout),
            nn.Linear(2, 2),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(2),
            nn.Dropout(self.dropout),
            nn.Linear(2, 2),
            nn.Softmax(dim=-1)
        )
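        # Caveat: the head ends in Softmax, so the loss must expect probabilities; if
        # nn.CrossEntropyLoss is used for training, the Softmax layer is usually dropped,
        # because CrossEntropyLoss applies log_softmax internally. The training loop is
        # not shown in this snippet, so this is a note rather than a fix.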

    def soft_attention_align(self, x1, x2, mask1, mask2):
        '''
        x1: batch_size * seq_len * dim
        x2: batch_size * seq_len * dim
        '''
        # attention: batch_size * seq_len * seq_len
        attention = torch.matmul(x1, x2.transpose(1, 2))
        # mask1 = mask1.float().masked_fill_(mask1, float('