PyTorch Implementation of the ESIM Algorithm

This post describes how to implement the ESIM algorithm (Enhanced Sequential Inference Model, from the paper "Enhanced LSTM for Natural Language Inference") in PyTorch. The model is organized into five main steps: encoding, attention, interaction, decoding, and aggregation, which together perform text inference.
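For orientation (this mapping follows the standard ESIM formulation and is not spelled out explicitly in the post), the five steps move a sentence pair through the following shapes, where B is the batch size, L the sequence length, and H the LSTM hidden size:

1. Encoding: embed each sentence and run a BiLSTM, giving a B x L x 2H encoding per sentence.
2. Attention: compute a B x L x L soft-alignment matrix between the two encodings.
3. Interaction: concatenate each encoding with its aligned counterpart plus their difference and element-wise product, giving B x L x 8H.
4. Decoding: run a second BiLSTM over the enhanced features, returning to B x L x 2H.
5. Aggregation: average- and max-pool each sentence over time, concatenate both sentences into a B x 8H vector, and classify.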

The ESIM model code is as follows:

# -*- coding: utf-8 -*-
# @Time : 2020/2/25 11:19
# @Author : liusen
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_helper import TEXT
from data_helper import train_iter, val_iter

weight_matrix = TEXT.vocab.vectors

class ESIM(nn.Module):
	def __init__(self):
		super(ESIM, self).__init__()
		self.dropout = 0.5
		self.hidden_size = 128
		self.embeds_dim = 300
		self.embeds = nn.Embedding(len(TEXT.vocab), self.embeds_dim)
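		# Note: weight_matrix (the pre-trained vectors loaded above) is never used in this
		# snippet; presumably it is meant to initialize the embedding weights, e.g.
		# self.embeds.weight.data.copy_(weight_matrix)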
		self.bn_embeds = nn.BatchNorm1d(self.embeds_dim)
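		# lstm1 encodes the embedded sentences; lstm2 composes the enhanced local-inference
		# features, which is why its input size is hidden_size * 8.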
		self.lstm1 = nn.LSTM(self.embeds_dim, self.hidden_size, batch_first=True, bidirectional=True)
		self.lstm2 = nn.LSTM(self.hidden_size * 8, self.hidden_size, batch_first=True, bidirectional=True)

		self.fc = nn.Sequential(
			nn.BatchNorm1d(self.hidden_size * 8),
			nn.Linear(self.hidden_size * 8, 2),
			nn.ELU(inplace=True),
			nn.BatchNorm1d(2),
			nn.Dropout(self.dropout),
			nn.Linear(2, 2),
			nn.ELU(inplace=True),
			nn.BatchNorm1d(2),
			nn.Dropout(self.dropout),
			nn.Linear(2, 2),
			nn.Softmax(dim=-1)
		)

	def soft_attention_align(self, x1, x2, mask1, mask2):
		'''
		x1: batch_size * seq_len * dim
		x2: batch_size * seq_len * dim
		'''
		# attention: batch_size * seq_len * seq_len
		attention = torch.matmul(x1, x2.transpose(1, 2))
		# Mask padding positions so they get zero attention weight after the softmax.
		# (The method is truncated in the original post; the rest follows the standard
		# ESIM soft-alignment formulation.)
		mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
		mask2 = mask2.float().masked_fill_(mask2, float('-inf'))

		# weight: batch_size * seq_len * seq_len
		weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
		x1_align = torch.matmul(weight1, x2)
		weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
		x2_align = torch.matmul(weight2, x1)
		# x1_align / x2_align: batch_size * seq_len * (2 * hidden_size)
		return x1_align, x2_align
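The post breaks off at this point. As a hedged sketch of how the class is typically completed in standard ESIM implementations (the helper names submul and apply_multiple, the forward signature, and the assumption that token id 0 is the padding index are assumptions, not taken from the original post), the remaining methods would look roughly like this:

	def submul(self, x1, x2):
		# element-wise difference and product of the original and aligned vectors
		mul = x1 * x2
		sub = x1 - x2
		return torch.cat([sub, mul], -1)

	def apply_multiple(self, x):
		# x: batch_size * seq_len * (2 * hidden_size)
		# average pooling and max pooling over the time dimension
		p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
		p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
		# output: batch_size * (4 * hidden_size)
		return torch.cat([p1, p2], 1)

	def forward(self, sent1, sent2):
		# sent1 / sent2: batch_size * seq_len tensors of token ids; 0 is assumed to be <pad>
		mask1, mask2 = sent1.eq(0), sent2.eq(0)

		# 1. Encoding: embed, batch-normalize, and run the first BiLSTM
		x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
		x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)
		o1, _ = self.lstm1(x1)
		o2, _ = self.lstm1(x2)

		# 2. Attention: soft alignment between the two sentences
		q1_align, q2_align = self.soft_attention_align(o1, o2, mask1, mask2)

		# 3. Interaction: [a; a_align; a - a_align; a * a_align] -> 8 * hidden_size
		q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
		q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)

		# 4. Decoding (composition): second BiLSTM over the enhanced features
		q1_compose, _ = self.lstm2(q1_combined)
		q2_compose, _ = self.lstm2(q2_combined)

		# 5. Aggregation: pool each sentence over time, concatenate, and classify
		q1_rep = self.apply_multiple(q1_compose)
		q2_rep = self.apply_multiple(q2_compose)
		x = torch.cat([q1_rep, q2_rep], -1)
		return self.fc(x)

Because the classifier above ends in nn.Softmax, a training loop over train_iter and val_iter would typically apply torch.log to the output and use nn.NLLLoss; passing the softmax output straight to nn.CrossEntropyLoss would apply a second softmax internally. The field names of the batches produced by data_helper are not shown in the post, so a full training loop is omitted here.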