首先把模型打印出来,具体可以看我的上一篇(103条消息) pth文件网络,结构可视化_我来保fu小仙女的博客-CSDN博客
每一层网络前面都有(xxx):yyy()
修改前
def forward(self,x):
x=self.model.xxx1(x)
x=self.model.xxx2(x)
x=self.model.xxx3(x)
例如不想使用xxx2层
修改为
def forward(self,x):
x=self.model.xxx1(x)
加入一些transform保证尺寸符合下一层的输入
x=self.model.xxx3(x)
!!!前提是把输入输出的通道和尺寸计算准确,保证符合网络运行标准。例如采用1*1卷积升降通道数。resize函数改变尺寸等等