LSTM Source Code Explained

This code implements the forward propagation of a multi-layer long short-term memory (LSTM) network, i.e. PyTorch's nn.LSTM. An LSTM is a kind of recurrent neural network (RNN) that is particularly good at processing and predicting time-series data. Each part of the code is explained in detail below.

Initialization method (__init__)

class LSTM(RNNBase):
    def __init__(self, *args, **kwargs):
        super().__init__('LSTM', *args, **kwargs)

The __init__ method initializes the LSTM layer. It simply calls the base class RNNBase's initializer, passing the mode string 'LSTM' along with all remaining arguments.
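
Since everything is forwarded to RNNBase, constructing an LSTM is just a matter of passing the usual RNN hyperparameters. A minimal sketch (the argument values here are arbitrary):

import torch.nn as nn

# These keyword arguments are consumed by RNNBase.__init__;
# LSTM.__init__ merely prepends the mode string 'LSTM'.
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,
               batch_first=True, bidirectional=True, dropout=0.1)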

Computing the expected cell size (get_expected_cell_size)

def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
    if batch_sizes is not None:
        mini_batch = int(batch_sizes[0])
    else:
        mini_batch = input.size(0) if self.batch_first else input.size(1)
    num_directions = 2 if self.bidirectional else 1
    expected_hidden_size = (self.num_layers * num_directions,
                            mini_batch, self.hidden_size)
    return expected_hidden_size
  • This method computes and returns the expected size of the cell state: (num_layers * num_directions, mini_batch, hidden_size). The mini-batch size comes from batch_sizes for packed input and from the appropriate input dimension otherwise, and num_directions is 2 for a bidirectional LSTM. Note that the cell state always uses hidden_size, even when proj_size > 0, which shrinks only the hidden state; the sketch below illustrates this.
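
A minimal sketch of the difference between the expected hidden-state and cell-state sizes when projections are enabled (all numbers here are arbitrary examples):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               bidirectional=True, proj_size=4)
x = torch.randn(5, 3, 8)      # (seq_len, batch, input_size)
h0 = torch.zeros(4, 3, 4)     # hidden state uses proj_size
c0 = torch.zeros(4, 3, 16)    # cell state always uses hidden_size
out, (hn, cn) = lstm(x, (h0, c0))
print(hn.shape, cn.shape)     # torch.Size([4, 3, 4]) torch.Size([4, 3, 16])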

Checking forward-pass arguments (check_forward_args)

def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]):
    self.check_input(input, batch_sizes)
    self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
                           'Expected hidden[0] size {}, got {}')
    self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
                           'Expected hidden[1] size {}, got {}')
  • This method checks that the input tensor and both state tensors have the expected sizes. hidden[0] (the hidden state) is validated against get_expected_hidden_size and hidden[1] (the cell state) against get_expected_cell_size; a mismatch raises a RuntimeError, as sketched below.
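
A quick demonstration of the check failing; the error message is produced from the format strings shown above (the sizes here are arbitrary):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
x = torch.randn(5, 3, 8)
bad_h0 = torch.zeros(1, 3, 32)   # wrong: hidden_size should be 16
c0 = torch.zeros(1, 3, 16)
try:
    lstm(x, (bad_h0, c0))
except RuntimeError as e:
    print(e)   # Expected hidden[0] size (1, 3, 16), got [1, 3, 32]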

Permuting the hidden state (permute_hidden)

def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
    if permutation is None:
        return hx
    return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
  • This method reorders both the hidden and cell states along the batch dimension according to the given permutation tensor. If no permutation is provided, the hidden state is returned unchanged. In practice the permutation comes from a PackedSequence built with enforce_sorted=False, as the sketch below shows.
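
A minimal sketch of where the permutation comes from: packing sequences of unsorted lengths with enforce_sorted=False produces sorted_indices and unsorted_indices, which forward() later hands to permute_hidden (the data here is random):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

lstm = nn.LSTM(input_size=4, hidden_size=8)
seqs = torch.randn(6, 3, 4)         # (max_seq_len, batch, input_size)
lengths = torch.tensor([3, 6, 4])   # not sorted, so indices are needed
packed = pack_padded_sequence(seqs, lengths, enforce_sorted=False)
print(packed.sorted_indices)        # expected: tensor([1, 2, 0])
out, (hn, cn) = lstm(packed)        # states are permuted internally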

The forward pass (forward)

def forward(self, input, hx=None):
    self._update_flat_weights()

    orig_input = input
    batch_sizes = None
    do_permute = False
    num_directions = 2 if self.bidirectional else 1
    real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
    
    if isinstance(orig_input, PackedSequence):
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        if hx is None:
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
        else:
            hx = self.permute_hidden(hx, sorted_indices)
    else:
        if input.dim() not in (2, 3):
            raise ValueError(f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead")
        is_batched = input.dim() == 3
        batch_dim = 0 if self.batch_first else 1
        if not is_batched:
            input = input.unsqueeze(batch_dim)
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None
        if hx is None:
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
            self.check_forward_args(input, hx, batch_sizes)
        else:
            if is_batched:
                if (hx[0].dim() != 3 or hx[1].dim() != 3):
                    msg = ("For batched 3-D input, hx and cx should "
                           f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                    raise RuntimeError(msg)
            else:
                if (hx[0].dim() != 2 or hx[1].dim() != 2):
                    msg = ("For unbatched 2-D input, hx and cx should "
                           f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                    raise RuntimeError(msg)
                hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
            self.check_forward_args(input, hx, batch_sizes)
            hx = self.permute_hidden(hx, sorted_indices)

    if batch_sizes is None:
        result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                          self.dropout, self.training, self.bidirectional, self.batch_first)
    else:
        result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                          self.num_layers, self.dropout, self.training, self.bidirectional)
    output = result[0]
    hidden = result[1:]
    if isinstance(orig_input, PackedSequence):
        output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output_packed, self.permute_hidden(hidden, unsorted_indices)
    else:
        if not is_batched:
            output = output.squeeze(batch_dim)
            hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
        return output, self.permute_hidden(hidden, unsorted_indices)
Code walkthrough:
  1. Initialization

    • self._update_flat_weights() refreshes the flattened weight buffers.
    • The variables orig_input, batch_sizes, do_permute, num_directions, and real_hidden_size are set up; real_hidden_size equals proj_size when projections are enabled, otherwise hidden_size.
  2. Handling PackedSequence input

    • If the input is a PackedSequence, it is unpacked into input, batch_sizes, sorted_indices, and unsorted_indices.
    • If no hidden state was supplied, hx is initialized to zeros sized from the largest batch; otherwise the given hx is reordered with sorted_indices.
  3. Handling non-PackedSequence input

    • The input must be 2-D (unbatched) or 3-D (batched); anything else raises a ValueError.
    • A 2-D input has a batch dimension of size 1 inserted at batch_dim.
    • If no hidden state was supplied, hx is initialized to zeros.
  4. Checking and adjusting the input and hidden state

    • check_forward_args validates the sizes of the input and both state tensors.
    • The hidden state is reordered via permute_hidden.
  5. Calling the underlying LSTM implementation

    • Depending on whether batch_sizes is present, the padded or the packed overload of _VF.lstm is called.
  6. Handling the output

    • If the input was a PackedSequence, the output is re-wrapped as a PackedSequence and the hidden state is restored to its original order with unsorted_indices.
    • Otherwise, for unbatched input, the batch dimension is squeezed back out of the output and hidden states. A usage sketch follows below.
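
A minimal end-to-end sketch of the batched and unbatched paths described above (shapes are arbitrary examples):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)

# Batched 3-D input: (batch, seq_len, input_size) because batch_first=True.
x3 = torch.randn(3, 5, 4)
out3, (h3, c3) = lstm(x3)
print(out3.shape, h3.shape)    # torch.Size([3, 5, 8]) torch.Size([2, 3, 8])

# Unbatched 2-D input: forward() unsqueezes a batch dim of 1 internally,
# then squeezes it back out of the output and both states.
x2 = torch.randn(5, 4)
out2, (h2, c2) = lstm(x2)
print(out2.shape, h2.shape)    # torch.Size([5, 8]) torch.Size([2, 8])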