源代码
class Encoder(nn.Module):
    """Single-layer LSTM encoder that returns the final hidden state.

    The LSTM is built with ``input_size=1`` and ``batch_first=True``, so it
    consumes tensors of shape ``(batch, seq_len, 1)``: one scalar feature per
    time step. A 2-D ``(batch, seq_len)`` tensor (e.g. ``torch.Size([32, 10000])``)
    is also accepted and gets a trailing feature axis of size 1.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=256, batch_first=True)
        self.norm = nn.LayerNorm(256)

    def forward(self, x):
        """Run the LSTM over ``x`` and return its final hidden state.

        Args:
            x: ``(batch, seq_len)`` or ``(batch, seq_len, 1)`` tensor. A 2-D
                input is always interpreted as ``(batch, seq_len)``, not as
                PyTorch's unbatched ``(seq_len, input_size)`` form.

        Returns:
            ``h_n`` of shape ``(num_layers * num_directions, batch, 256)``.
            With this single-layer, unidirectional LSTM that is ``(1, batch, 256)``.
        """
        if x.dim() == 2:
            # (batch, seq_len) -> (batch, seq_len, 1). The last axis must equal
            # input_size=1. Passing the 2-D tensor directly raises
            # "input must have 3 dimensions, got 2".
            x = x.unsqueeze(-1)
        output, (hidden, _) = self.lstm(x)
        # NOTE(review): the normalized sequence output is computed but never
        # returned. Return it (or drop this line) if it is actually needed.
        output = self.norm(output)
        return hidden
报错:raise RuntimeError(
RuntimeError: input must have 3 dimensions, got 2
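维度不匹配问题：nn.LSTM 在 batch_first=True 时要求输入为 3 维张量 (batch, seq_len, input_size)，且最后一维必须等于 input_size=1；而这里的 x 只有 2 维 (batch, seq_len) = (32, 10000)。修改后的代码如下：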
class Encoder(nn.Module):
    """Single-layer LSTM encoder that returns the final hidden state.

    The LSTM is built with ``input_size=1`` and ``batch_first=True``, so it
    consumes tensors of shape ``(batch, seq_len, 1)``: one scalar feature per
    time step. A 2-D ``(batch, seq_len)`` tensor (e.g. ``torch.Size([32, 10000])``)
    is also accepted and gets a trailing feature axis of size 1.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=256, batch_first=True)
        self.norm = nn.LayerNorm(256)

    def forward(self, x):
        """Run the LSTM over ``x`` and return its final hidden state.

        Args:
            x: ``(batch, seq_len)`` or ``(batch, seq_len, 1)`` tensor. A 2-D
                input is always interpreted as ``(batch, seq_len)``, not as
                PyTorch's unbatched ``(seq_len, input_size)`` form.

        Returns:
            ``h_n`` of shape ``(num_layers * num_directions, batch, 256)``.
            With this single-layer, unidirectional LSTM that is ``(1, batch, 256)``.
        """
        if x.dim() == 2:
            # (32, 10000) -> (32, 10000, 1): the feature axis goes LAST.
            # view(len(x), 1, -1) would give (32, 1, 10000), i.e. seq_len=1 and
            # 10000 features per step, which LSTM(input_size=1) rejects.
            x = x.unsqueeze(-1)
        output, (hidden, _) = self.lstm(x)
        # NOTE(review): the normalized sequence output is computed but never
        # returned. Return it (or drop this line) if it is actually needed.
        output = self.norm(output)
        return hidden