A simple LSTM implementation

A quick note so I don't forget.
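
For reference, the LSTM cell equations that the code below implements (the two nn.Linear bias terms of each gate fold into a single $b$):

$$
\begin{aligned}
i_t &= \sigma(U_i x_t + V_i h_{t-1} + b_i) \\
f_t &= \sigma(U_f x_t + V_f h_{t-1} + b_f) \\
g_t &= \tanh(U_g x_t + V_g h_{t-1} + b_g) \\
o_t &= \sigma(U_o x_t + V_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$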

import torch
import torch.nn as nn


class myLstm(nn.Module):
    # input_sz is the embedding (character vector) dimension
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz

        # i_t
        self.U_i = nn.Linear(input_sz, hidden_sz)
        self.V_i = nn.Linear(hidden_sz, hidden_sz)

        # f_t
        self.U_f = nn.Linear(input_sz, hidden_sz)
        self.V_f = nn.Linear(hidden_sz, hidden_sz)

        # g_t (candidate cell state)
        self.U_g = nn.Linear(input_sz, hidden_sz)
        self.V_g = nn.Linear(hidden_sz, hidden_sz)

        # o_t
        self.U_o = nn.Linear(input_sz, hidden_sz)
        self.V_o = nn.Linear(hidden_sz, hidden_sz)

    def forward(self, x, init_states=None):
        # x: (batch_size, seq_len, embedding_dim)
        x = x.float()
        bs, seq_sz, _ = x.size()
        hidden_seq = []

        # initialize the hidden state h_0 and cell state c_0
        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_size, device=x.device),
                torch.zeros(bs, self.hidden_size, device=x.device)
            )
        else:
            h_t, c_t = init_states

        # step through position t of every sequence in the batch
        for t in range(seq_sz):
            x_t = x[:, t, :]

            # input gate i_t
            i_t = torch.sigmoid(self.U_i(x_t) + self.V_i(h_t))
            # forget gate f_t
            f_t = torch.sigmoid(self.U_f(x_t) + self.V_f(h_t))
            # candidate cell state g_t
            g_t = torch.tanh(self.U_g(x_t) + self.V_g(h_t))
            # output gate o_t
            o_t = torch.sigmoid(self.U_o(x_t) + self.V_o(h_t))
            # all four take x_t and h_t as input and have shape (batch_size, hidden_size)

            # c_t on the right-hand side is still c_{t-1}; compute the new cell state
            c_t = f_t * c_t + i_t * g_t
            # the new hidden state comes from o_t and the updated cell state
            h_t = o_t * torch.tanh(c_t)

            # record h_t at every time step
            hidden_seq.append(h_t.unsqueeze(0))

        # hidden_seq is a list of seq_len tensors, each (1, batch_size, hidden_size);
        # concatenate it into (seq_len, batch_size, hidden_size)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # rearrange to (batch_size, seq_len, hidden_size)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)

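# Usage sketch (illustrative): in a real model the input is usually integer
# token ids passed through nn.Embedding first, e.g.
#
#   embed = nn.Embedding(num_embeddings=3000, embedding_dim=32)
#   lstm = myLstm(32, 64)
#   ids = torch.randint(0, 3000, (8, 20))      # (batch_size, seq_len)
#   hidden_seq, (h_n, c_n) = lstm(embed(ids))  # hidden_seq: (8, 20, 64)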

if __name__ == '__main__':
    # quick test
    batch_size = 100
    hidden_size = 64
    embedding_dim = 32
    min_id = 1
    max_id = 3000

    model = myLstm(embedding_dim, hidden_size)
    # random integers stand in for embeddings here (forward() casts x to float)
    x = torch.randint(min_id, max_id, (batch_size, 50, embedding_dim))
    print(model)
    print(model(x))
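
    # Sanity check (illustrative sketch): compare output shapes with PyTorch's
    # built-in nn.LSTM using batch_first=True. The values will differ because
    # the two modules initialize their weights independently.
    ref = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
    ref_seq, (ref_h, ref_c) = ref(x.float())
    out_seq, (h_n, c_n) = model(x)
    assert out_seq.shape == ref_seq.shape        # (batch_size, 50, hidden_size)
    assert h_n.shape == ref_h.squeeze(0).shape   # (batch_size, hidden_size)
    assert c_n.shape == ref_c.squeeze(0).shape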
