I've recently been reading the paper "Minimal Gated Unit for Recurrent Neural Networks". In short, it replaces the two gates of a GRU with a single gate. The implementation took a fair amount of time to write, so I'm recording it here.
It can also serve as a reference for similar modifications to recurrent cells in the future.
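For reference, the update equations that the cell below implements are (x_t is the input at step t, h_{t-1} the previous hidden state, ⊙ is element-wise multiplication; the extra bias b2 inside the candidate term follows the code, which is a slight variant of the paper's formulation):

f_t     = sigmoid(x_t·W_f + h_{t-1}·U_f + b_f)
h_hat_t = tanh(x_t·W_h + b_h + f_t ⊙ (h_{t-1}·U_h + b2))
h_t     = (1 - f_t) ⊙ h_{t-1} + f_t ⊙ h_hat_t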
The MGU cell:
import math
import torch
import torch.nn as nn

class NaiveCustomMGU(nn.Module):
    """Minimal Gated Unit: a single forget gate replaces the GRU's two gates."""

    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # forget gate f_t
        self.W_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.U_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
        # candidate state h_hat_t
        self.W_h_h = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.U_h_h = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_h_h = nn.Parameter(torch.Tensor(hidden_sz))
        self.b2_h_h = nn.Parameter(torch.Tensor(hidden_sz))
        self.init_weights()

    def init_weights(self):
        # uniform initialization in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        """
        Assumes x.shape is (batch_size, sequence_size, input_size).
        Returns the full hidden sequence and the final (h_t, c_t) pair.
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            # MGU has no cell state; c_t is kept only so the return value
            # mirrors an LSTM-style (h, c) interface.
            h_t, c_t = (
                torch.zeros(bs, self.hidden_size).to(x.device),
                torch.zeros(bs, self.hidden_size).to(x.device),
            )
        else:
            h_t, c_t = init_states
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # f_t = sigmoid(x_t·W_f + h_{t-1}·U_f + b_f)
            f_t = torch.sigmoid(x_t @ self.W_f + h_t @ self.U_f + self.b_f)
            # h_hat_t = tanh(x_t·W_h + b_h + f_t ⊙ (h_{t-1}·U_h + b2))
            h_hat_t = torch.tanh(x_t @ self.W_h_h + self.b_h_h + f_t * (h_t @ self.U_h_h + self.b2_h_h))
            # h_t = (1 - f_t) ⊙ h_{t-1} + f_t ⊙ h_hat_t
            h_t = (1 - f_t) * h_t + f_t * h_hat_t
            hidden_seq.append(h_t.unsqueeze(0))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # (seq, batch, hidden) -> (batch, seq, hidden)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
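A quick shape check for the cell (a minimal sketch; the sizes below are arbitrary):

# Sanity check: run the MGU cell on random input and inspect the shapes.
mgu = NaiveCustomMGU(input_sz=16, hidden_sz=32)
x = torch.randn(4, 10, 16)        # (batch=4, seq=10, input=16)
hidden_seq, (h_n, c_n) = mgu(x)
print(hidden_seq.shape)           # torch.Size([4, 10, 32])
print(h_n.shape)                  # torch.Size([4, 32])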
Implementing a bidirectional version in the model:
import torch.nn.functional as F

class BiMGU(nn.Module):
    @staticmethod
    def reverse(x):
        # Reverse the sequence (time) dimension of a batch-first tensor
        # of shape (batch, seq, features).
        idx = torch.arange(x.size(1) - 1, -1, -1, device=x.device)
        return x.index_select(1, idx)

    def __init__(self, max_words, input_size, hidden_size, emb_size, out_size):
        super(BiMGU, self).__init__()
        self.emb_size = emb_size
        self.max_words = max_words
        self.hid_size = hidden_size
        self.Embedding = nn.Embedding(self.max_words, self.emb_size)
        # input_size must equal emb_size, since the MGU consumes the embeddings.
        # Note: the same MGU (shared weights) is reused for both directions.
        self.mgu = NaiveCustomMGU(input_size, hidden_size)
        self.fc1 = nn.Linear(self.hid_size * 2, self.hid_size)
        # output dimension is hard-coded to 2 (binary classification); out_size is unused
        self.fc2 = nn.Linear(self.hid_size, 2)

    def forward(self, x):
        x = self.Embedding(x)              # [bs, ml, emb_size]
        y_1, h_1 = self.mgu(x)             # forward direction
        x_t = BiMGU.reverse(x)             # reverse the sequence
        y_2, h_2 = self.mgu(x_t)           # backward direction
        y_2 = BiMGU.reverse(y_2)           # re-align with the forward outputs
        x = torch.cat((y_1, y_2), 2)       # [bs, ml, 2*hid_size]
        x = F.relu(self.fc1(x))            # [bs, ml, hid_size]
        # average over the sequence dimension: [bs, ml, hid_size] -> [bs, hid_size]
        x = F.avg_pool2d(x, (x.shape[1], 1)).squeeze(1)
        out = self.fc2(x)                  # [bs, 2]
        return out
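A minimal usage sketch for the bidirectional model (the vocabulary size, sequence length, and hyperparameters are made up for illustration; in practice the model and inputs would be moved to the GPU with .to(device)):

# Hypothetical hyperparameters, chosen only to illustrate the shapes.
model = BiMGU(max_words=10000, input_size=64, hidden_size=32, emb_size=64, out_size=2)
tokens = torch.randint(0, 10000, (8, 20))   # batch of 8 sequences, 20 token ids each
logits = model(tokens)
print(logits.shape)                          # torch.Size([8, 2])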