import torch.nn as nn
import math
import torch
class GateConcMechanism(nn.Module):
    """Gate two hidden representations and concatenate the gated halves.

    A sigmoid gate is computed from both inputs::

        gate   = sigmoid(input @ w1.T + hidden @ w2.T + bias)
        output = cat([gate * input, (1 - gate) * hidden], dim=-1)

    Both inputs must have ``hidden_size`` as their last dimension (they are
    multiplied by ``hidden_size x hidden_size`` matrices), so the output's
    last dimension is ``2 * hidden_size``.

    Args:
        hidden_size: Size of the last dimension of both inputs. Required;
            the ``None`` default exists only for signature compatibility.

    Raises:
        TypeError: If ``hidden_size`` is ``None``.
    """

    def __init__(self, hidden_size=None):
        super(GateConcMechanism, self).__init__()
        if hidden_size is None:
            # Fail early with a readable message instead of the opaque
            # TypeError raised by tensor construction with None sizes.
            raise TypeError("GateConcMechanism requires an integer hidden_size")
        self.hidden_size = hidden_size
        # Allocate uninitialised storage; real values are drawn in
        # reset_parameters() right below.
        self.w1 = nn.Parameter(torch.empty(self.hidden_size, self.hidden_size))
        self.w2 = nn.Parameter(torch.empty(self.hidden_size, self.hidden_size))
        self.bias = nn.Parameter(torch.empty(self.hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all parameters from uniform distributions.

        Each weight matrix is sampled from ``U(-b, b)`` with
        ``b = 1 / sqrt(number of columns)``; the bias uses the mean of the
        two weight bounds. The draw order (w1, w2, bias) is fixed so results
        are reproducible under a given random seed.
        """
        bound_w1 = 1.0 / math.sqrt(self.w1.size(1))
        bound_w2 = 1.0 / math.sqrt(self.w2.size(1))
        bound_bias = (bound_w1 + bound_w2) / 2.0
        # nn.init.uniform_ fills in place without recording autograd history,
        # replacing direct mutation of `.data`.
        nn.init.uniform_(self.w1, -bound_w1, bound_w1)
        nn.init.uniform_(self.w2, -bound_w2, bound_w2)
        # bias is always created in __init__; the guard only matters if the
        # attribute has been replaced with None afterwards.
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound_bias, bound_bias)

    def forward(self, input, hidden):
        """Return ``[gate * input; (1 - gate) * hidden]`` along the last dim.

        Args:
            input: Tensor whose last dimension is ``hidden_size``. According
                to the original author's notes this is the hidden state from
                the encoder. (The name shadows the builtin but is kept so
                keyword callers keep working.)
            hidden: Tensor broadcast-compatible with ``input`` whose last
                dimension is ``hidden_size``. According to the original
                author's notes this is the hidden state from a key-value
                memory network.

        Returns:
            Tensor whose last dimension is ``2 * hidden_size``: the gated
            ``input`` followed by the complementary-gated ``hidden``.
        """
        pre_activation = (
            input.matmul(self.w1.t()) + hidden.matmul(self.w2.t()) + self.bias
        )
        gate = torch.sigmoid(pre_activation)
        # The two gated halves are concatenated rather than summed, so the
        # feature dimension doubles.
        return torch.cat([input * gate, hidden * (1 - gate)], dim=-1)
# Gate mechanism (first published 2021-02-08 10:47:20)