项目场景:
最近在复现一篇AAAI2020年的论文《Improving ECG Classification Using Generative Adversarial Networks》,记录一下碰到的比较棘手的bug。
问题描述:
在跑代码时,碰到了一个网络输入与网络维度不匹配的问题,图片如下:
问题是说网络的维度为216X100,但是传入网络的参数维度为2150X5,所以就报错了。
解决方案:
作为一个接触机器学习不久的新手,改正这个错误花了很多时间。具体改正的办法如下文。
将Classifiers文件夹下的main.py中的
dataset = ecg_dataset_pytorch.EcgHearBeatsDataset(transform=composed)
……
testset = ecg_dataset_pytorch.EcgHearBeatsDatasetTest(transform=composed)
改为
dataset = ecg_dataset_pytorch.EcgHearBeatsDataset(transform=composed, lstm_setting=False)
……
testse