The code snippet below comes from https://github.com/linjieli222/VQA_ReGAT/blob/master/model/language_model.py. It wraps the LSTM and GRU implementations provided by PyTorch. In the bidirectional case, forward concatenates two hidden states: the forward direction's output at the last timestep and the backward direction's output at the first timestep. forward_all returns the outputs of all timesteps.
Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM
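As a quick sanity check of that output layout (a minimal sketch, not from the repo; the input size, hidden size, and sequence length are made-up illustration values):

import torch
import torch.nn as nn

torch.manual_seed(0)
hid = 4
rnn = nn.LSTM(input_size=3, hidden_size=hid, num_layers=1,
              bidirectional=True, batch_first=True)
x = torch.randn(2, 5, 3)             # [batch, sequence, in_dim]
output, (h_n, c_n) = rnn(x)          # output: [2, 5, 2 * hid]

# For a bidirectional RNN, the first half of the feature dimension is the
# forward direction and the second half is the backward direction. Each
# direction's final state sits at the timestep where it finished reading:
assert torch.allclose(output[:, -1, :hid], h_n[0])  # forward, last step
assert torch.allclose(output[:, 0, hid:], h_n[1])   # backward, first step

With that layout confirmed, here is the module itself: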
import torch
import torch.nn as nn


class QuestionEmbedding(nn.Module):
    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout,
                 rnn_type='GRU'):
        """Module for question embedding"""
        super(QuestionEmbedding, self).__init__()
        assert rnn_type == 'LSTM' or rnn_type == 'GRU'
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)
        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type
        # 1 for a unidirectional RNN, 2 for a bidirectional one
        self.ndirections = 1 + int(bidirect)

    def init_hidden(self, batch):
        # Borrow a parameter tensor just to match its dtype/device, then
        # build zero-initialized hidden states of the right shape:
        # [nlayers * ndirections, batch, num_hid]
        weight = next(self.parameters()).data
        hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            # An LSTM needs both a hidden state and a cell state
            return (weight.new(*hid_shape).zero_(),
                    weight.new(*hid_shape).zero_())
        else:
            return weight.new(*hid_shape).zero_()

    def forward(self, x):
        # x: [batch, sequence, in_dim]
        batch = x.size(0)
        hidden = self.init_hidden(batch)
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(x, hidden)
        if self.ndirections == 1:
            # Unidirectional: the last timestep summarizes the sequence
            return output[:, -1]
        # Bidirectional: concatenate the forward direction's last timestep
        # with the backward direction's first timestep -> [batch, 2 * num_hid]
        forward_ = output[:, -1, :self.num_hid]
        backward = output[:, 0, self.num_hid:]
        return torch.cat((forward_, backward), dim=1)

    def forward_all(self, x):
        # x: [batch, sequence, in_dim]
        batch = x.size(0)
        hidden = self.init_hidden(batch)
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(x, hidden)
        # Return the hidden states of every timestep:
        # [batch, sequence, num_hid * ndirections]
        return output
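A minimal usage sketch, assuming made-up dimensions (300-d word embeddings, 512 hidden units, 14-token questions, batch of 32); these numbers are illustrative, not taken from the repo:

# Hypothetical sizes chosen only for illustration
q_emb = QuestionEmbedding(in_dim=300, num_hid=512, nlayers=1,
                          bidirect=False, dropout=0.0, rnn_type='GRU')
q = torch.randn(32, 14, 300)   # [batch, sequence, in_dim]
last = q_emb(q)                # [32, 512]: last timestep only
seq = q_emb.forward_all(q)     # [32, 14, 512]: every timestep

With bidirect=True, forward would instead return a [32, 1024] tensor, since the final states of the two directions are concatenated.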