使用自己的训练集进行图像分类是出现
RuntimeError: Shape not matching: the Program requires a parameter with a shape of ((2048, 9)), while the loaded parameter (namely [ cls_out_w ]) has a shape of ((2048, 1)).
我的分类类别是9个,我猜测可能是分类类别什么地方出错了。为了验证我的猜想,我减少了数据集中的一个类别
RuntimeError: Shape not matching: the Program requires a parameter with a shape of ((2048, 8)), while the loaded parameter (namely [ cls_out_w ]) has a shape of ((2048, 1)).
原来问题出在预训练模型上
把预训练模型的参数删除就可以,因为分类数不一样,有些参数不能用了