This code implements the forward pass of a multi-layer Long Short-Term Memory (LSTM) network. An LSTM is a type of recurrent neural network (RNN) that is particularly good at processing and predicting sequential, time-based data. Below is a detailed walkthrough of each part of the code.
Initialization method (`__init__`)

```python
class LSTM(RNNBase):
    def __init__(self, *args, **kwargs):
        super().__init__('LSTM', *args, **kwargs)
```

The `__init__` method initializes the LSTM layer. It calls the base class `RNNBase`'s initializer, passing the mode string `'LSTM'` along with the remaining arguments.
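To make the delegation pattern concrete, here is a toy sketch (not the real PyTorch classes; the attribute names are illustrative assumptions) showing how the subclass adds nothing but the `'LSTM'` mode tag:

```python
# Toy sketch: a minimal stand-in for RNNBase, illustrating how LSTM.__init__
# forwards a mode string plus all remaining arguments to the base class.
class RNNBase:
    def __init__(self, mode, input_size, hidden_size, num_layers=1):
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

class LSTM(RNNBase):
    def __init__(self, *args, **kwargs):
        # The only thing the subclass contributes is the 'LSTM' mode tag.
        super().__init__('LSTM', *args, **kwargs)

lstm = LSTM(10, 20, num_layers=2)
print(lstm.mode, lstm.hidden_size)  # LSTM 20
```

In the real `RNNBase`, the mode string selects which gate weights to allocate (an LSTM needs four sets of gates per layer).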
Getting the expected cell size (`get_expected_cell_size`)

```python
def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
    if batch_sizes is not None:
        mini_batch = int(batch_sizes[0])
    else:
        mini_batch = input.size(0) if self.batch_first else input.size(1)
    num_directions = 2 if self.bidirectional else 1
    expected_hidden_size = (self.num_layers * num_directions,
                            mini_batch, self.hidden_size)
    return expected_hidden_size
```

- This method computes and returns the expected size of the LSTM's cell state. `mini_batch` is taken from `batch_sizes` for packed input, or from the input tensor's batch dimension (which depends on `batch_first`) otherwise; the first dimension doubles when the LSTM is bidirectional.
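The shape arithmetic can be replicated as a standalone function on plain Python values (an illustrative sketch, not the real method, with hypothetical parameter names):

```python
from typing import Optional, Sequence, Tuple

# Standalone replica of the shape logic in get_expected_cell_size.
def expected_cell_size(input_shape: Sequence[int],
                       batch_sizes: Optional[Sequence[int]],
                       num_layers: int, bidirectional: bool,
                       batch_first: bool, hidden_size: int) -> Tuple[int, int, int]:
    if batch_sizes is not None:
        # Packed input: the first entry is the largest batch in the sequence.
        mini_batch = int(batch_sizes[0])
    else:
        # Padded input: which dimension is the batch depends on batch_first.
        mini_batch = input_shape[0] if batch_first else input_shape[1]
    num_directions = 2 if bidirectional else 1
    return (num_layers * num_directions, mini_batch, hidden_size)

# A 2-layer bidirectional LSTM with hidden_size=16 on a (batch=4, seq=7, feat=8)
# batch-first input expects a cell state of shape (4, 4, 16).
print(expected_cell_size((4, 7, 8), None, num_layers=2,
                         bidirectional=True, batch_first=True, hidden_size=16))
```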
Checking forward-pass arguments (`check_forward_args`)

```python
def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]):
    self.check_input(input, batch_sizes)
    self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
                           'Expected hidden[0] size {}, got {}')
    self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
                           'Expected hidden[1] size {}, got {}')
```

- This method verifies that the input tensor and both parts of the hidden state (`hidden[0]` is the hidden state, `hidden[1]` the cell state) have the expected sizes, delegating to `check_input` and `check_hidden_size`.
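What such a check amounts to can be sketched as follows (a hypothetical helper, not the actual `RNNBase.check_hidden_size` implementation): compare the actual shape against the expected one and raise with a formatted message.

```python
# Minimal sketch of a hidden-size check: raises RuntimeError on mismatch,
# formatting the message template the same way the calls above do.
def check_hidden_size(actual_shape, expected_size,
                      msg='Expected hidden size {}, got {}'):
    if tuple(actual_shape) != tuple(expected_size):
        raise RuntimeError(msg.format(tuple(expected_size), tuple(actual_shape)))

check_hidden_size((2, 4, 16), (2, 4, 16))  # matching shapes: passes silently
try:
    check_hidden_size((2, 3, 16), (2, 4, 16), 'Expected hidden[0] size {}, got {}')
except RuntimeError as e:
    print(e)  # Expected hidden[0] size (2, 4, 16), got (2, 3, 16)
```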
Permuting the hidden state (`permute_hidden`)

```python
def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
    if permutation is None:
        return hx
    return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
```

- This method reorders both the hidden state and the cell state along the batch dimension according to the given permutation tensor. If no permutation is provided, the original state is returned unchanged.
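The effect of `_apply_permutation` can be shown with a pure-Python sketch, modeling a `(layers, batch, hidden)` state as nested lists (illustrative only; the real helper uses `Tensor.index_select` on the batch dimension):

```python
# Reorder the batch dimension (dim 1 of a (layers, batch, hidden) state)
# by an index list, mirroring what _apply_permutation does on tensors.
def apply_permutation(state, permutation):
    return [[layer[i] for i in permutation] for layer in state]

def permute_hidden(hx, permutation):
    if permutation is None:
        return hx  # no reordering requested
    return (apply_permutation(hx[0], permutation),
            apply_permutation(hx[1], permutation))

h = [[['h0'], ['h1'], ['h2']]]   # 1 layer, batch of 3
c = [[['c0'], ['c1'], ['c2']]]
h2, c2 = permute_hidden((h, c), [2, 0, 1])
print(h2)  # [[['h2'], ['h0'], ['h1']]]
```

This is exactly the reordering needed when a `PackedSequence` was built with `enforce_sorted=False`: the sequences were sorted by length internally, so a user-supplied `hx` must be sorted the same way on the way in and restored on the way out.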
Forward pass method (`forward`)

```python
def forward(self, input, hx=None):
    self._update_flat_weights()
    orig_input = input
    batch_sizes = None
    do_permute = False
    num_directions = 2 if self.bidirectional else 1
    real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
    if isinstance(orig_input, PackedSequence):
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        if hx is None:
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
        else:
            hx = self.permute_hidden(hx, sorted_indices)
    else:
        if input.dim() not in (2, 3):
            raise ValueError(f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead")
        is_batched = input.dim() == 3
        batch_dim = 0 if self.batch_first else 1
        if not is_batched:
            input = input.unsqueeze(batch_dim)
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None
        if hx is None:
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
            self.check_forward_args(input, hx, batch_sizes)
        else:
            if is_batched:
                if hx[0].dim() != 3 or hx[1].dim() != 3:
                    msg = ("For batched 3-D input, hx and cx should "
                           f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                    raise RuntimeError(msg)
            else:
                if hx[0].dim() != 2 or hx[1].dim() != 2:
                    msg = ("For unbatched 2-D input, hx and cx should "
                           f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                    raise RuntimeError(msg)
                hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
            self.check_forward_args(input, hx, batch_sizes)
            hx = self.permute_hidden(hx, sorted_indices)
    if batch_sizes is None:
        result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                          self.dropout, self.training, self.bidirectional, self.batch_first)
    else:
        result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                          self.num_layers, self.dropout, self.training, self.bidirectional)
    output = result[0]
    hidden = result[1:]
    if isinstance(orig_input, PackedSequence):
        output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output_packed, self.permute_hidden(hidden, unsorted_indices)
    else:
        if not is_batched:
            output = output.squeeze(batch_dim)
            hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
        return output, self.permute_hidden(hidden, unsorted_indices)
```
Code walkthrough:

- Initialization: `self._update_flat_weights()` refreshes the flattened weight buffers. The variables `orig_input`, `batch_sizes`, `do_permute`, `num_directions`, and `real_hidden_size` are then initialized (`real_hidden_size` accounts for an optional output projection via `proj_size`).
- Handling `PackedSequence` input: if the input is a `PackedSequence`, it is unpacked into `input`, `batch_sizes`, `sorted_indices`, and `unsorted_indices`. If no hidden state was passed in, `hx` is initialized to zeros sized from the input; otherwise it is reordered with `sorted_indices`.
- Handling non-`PackedSequence` input: the input must be 2D (unbatched) or 3D (batched); a 2D input is unsqueezed to 3D. `hx` is then either initialized to zeros, or its dimensionality is validated (and unsqueezed for unbatched input).
- Checking and adjusting input and hidden state: `check_forward_args` validates the sizes of the input and hidden state, and `permute_hidden` reorders the hidden state to match the sorted batch.
- Calling the underlying LSTM implementation: depending on whether `batch_sizes` is present, the padded or packed overload of `_VF.lstm` is invoked.
- Handling the output: if the input was a `PackedSequence`, the result is repacked into a `PackedSequence`; otherwise, for unbatched input, the batch dimension is squeezed out of the output and hidden state. In both cases the hidden state is un-permuted with `unsorted_indices` before being returned.
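The whole forward path described above can be exercised end to end through the public `nn.LSTM` API, for both padded and packed inputs (shapes chosen arbitrarily for illustration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

# Padded, batched input: (batch=3, seq=5, features=10); hx defaults to zeros.
x = torch.randn(3, 5, 10)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([3, 5, 20])
print(h_n.shape)     # torch.Size([2, 3, 20]) -> num_layers * num_directions

# Packed input: with enforce_sorted=False, forward uses sorted_indices /
# unsorted_indices to permute hx on the way in and restore order on the way out.
lengths = torch.tensor([5, 3, 2])
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
unpacked, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(unpacked.shape)  # torch.Size([3, 5, 20])
```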