问题描述
提示:这里描述项目中遇到的问题:
使用LSTM网络中出现:
Traceback (most recent call last):
File "C:/Users/96552/Desktop/DFMNET/main.py", line 184, in <module>
dfmnet.fit(train_x, train_y)
File "C:/Users/96552/Desktop/DFMNET/main.py", line 138, in fit
y_pred = self(x)
File "C:\Users\96552\.conda\envs\pytorchcpu\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "C:/Users/96552/Desktop/DFMNET/main.py", line 108, in forward
r = r[-1, :, :]
TypeError: tuple indices must be integers or slices, not tuple
维度错误,数据维度跟网络为度不一致,重点检查 def forward(),发现pytorch自带的LSTM
def forward(self, x):
r, _ = self.LSTM(x.transpose(1, 0))
return y
改成下面后就不报错:
def forward(self, x):
r, _ = self.LSTM(x.transpose(1, 0))
return y