关于PyTorch中模型保存与加载的问题

博主分享了在模型加载、数据转换和梯度爆炸问题上的经验,包括模型保存与nn.Flatten/torch.flatten的区别,以及未归一化导致的预测偏差。他还揭示了解决导入代码自动训练的方法和重要细节调整。
摘要由CSDN通过智能技术生成

我真服了,明明这么简单的东西,居然卡了我一个晚上,哭死

1.模型保存与加载推理的问题

代码如下,很简单,直接用既保持模型,又保存checkpoints的就好了,虽然不推荐,但是简单啊!

torch.save(SncNet, 'final.pth')
model = torch.load("best.pth")

2.关于nn.Flatten()与torch.flatten()(懒得写了,如下图所示)

由于我的模型是全部由全连接层构成,输入数据是4*3,因此网络第一层必然是nn.Flatten(),但是nn.Flatten()出现在网络的定义中,因此它是对于第一维度的拉伸,反之,torch.flatten()是对于第零维度的拉伸。因此要想输入一个4*3的数据去已保存的模型中,首先第一步是需要把他reshape成1*4*3的格式,或者在网络输入输入的时候直接将数据输入成1*12的格式。

3.在做预测模型的时候,出现了预测出全是0和1的情况,很显然是因为经过了sigmoid之后,梯度爆炸,导致最后两极分化,如下图所示,这是我经过断点之后,查看sigmoid之前的数据,发现产生了梯度爆炸(谢谢我的彬神,告诉我该怎么debug呜呜呜),但是再训练的时候看输出明明好好的,所以问题肯定出在了数据导入,紧接着就发现,我忘记对数据做归一化,导致梯度爆炸!!! 

  

4.在导入模型的时候,他直接给我再训练了一次,这是因为我直接import了我神经网络的代码,而我却在main里面进行训练,所以在import之后计算机自动给我运行了main,所以就又训练了一次。。。我麻了。。。解决办法就是,要么def train,要么就只复制网络结构。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值