attention+pytorch+时间序列数据预测

将用于NLP的Encoder-Decoder修改用于时间序列数据预测,实验发现添加注意力机制后预测效果能够得到提升。
class Encoder (nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.rnn=nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers= 1,
            batch_first=True
        )
    def forward(self,x):
        r_out, (hidden,cell) = self.rnn(x)
        print(r_out.shape)
        return r_out,hidden,cell

class Decoder (nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn=nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers= 1,
            batch_first=True
        )
        self.out=nn.Linear(HIDDEN_SIZE,1)
    def forward(self,x,hidden,cell):
        print("x:", x.shape)
        output, (hidden,cell) = self.rnn(x,(hidden,cell))
        print("output:", output.shape)
        print("output.squeeze(0):", output.squeeze(0).shape)
        prediction = self.out(output.squeeze(0))
        print("prediction:",prediction.shape)
        return  prediction,hidden,cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def attention_net(self, lstm_output, final_state):

        hidden = final_state.view(-1, HIDDEN_SIZE , 1) # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]
        # print("----------------------------------------------------")
        # print("hidden的值:", hidden.shape)
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]
        # print("attn_weights的值:", attn_weights.shape)
        soft_attn_weights = F.softmax(attn_weights, 1)
        # print("soft_attn_weights的值:", soft_attn_weights.shape)
        # print("soft_attn_weights.unsqueeze(2)的值:", soft_attn_weights.unsqueeze(2).shape)
        # print("torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2))的值:", torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).shape)
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        print("context的值:", context.shape)
        return context, soft_attn_weights.data.numpy() # context : [batch_size, n_hidden * num_directions(=2)]


    def forward(self,src):

        src_len=src.shape[0]
        batch_size = src.shape[1]
        outputs =torch.zeros(src_len, batch_size, 1).to(self.device).double()
        # print("------------------------------")
        # print("outputs:",outputs.shape)
        print(src.shape)
        r_out,hidden,cell = self.encoder(src)

        print("r_out",r_out.shape)
        print("hidden", hidden.shape)

        attn_output, attention = self.attention_net(r_out, hidden)
        hidden = attn_output.view(1, -1, HIDDEN_SIZE)

        # print("hidden___",hidden.shape)
        # print("attn_output",attn_output.shape)
        # print("attention", attention.shape)
        # print("------------------------------")
        # print("src:", src.shape)
        # print("hidden:",hidden.shape)
        # print("cell:",cell.shape)
        # print("------------------------------")

        for t in range(1,batch_size):
            input=src[:,t-1,:].unsqueeze(1)
            print("input:",input.shape)
            output, hidden, cell = self.decoder(input, hidden, cell)
            print("------------------------------")
            print("output:",output.shape)
            print("hidden:",  hidden.shape)
            print("cell:", cell.shape)
            print("outputs:", outputs.shape)
            print("outputs[:,t,:]:", outputs[:,t-1,:].unsqueeze(1).shape)
            outputs[:,t-1,:]=output.squeeze(1)
        print("------------------------------")
        print("outputs:",outputs.shape)
        return outputs

  • 6
    点赞
  • 57
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值