苦恼了将近半个月的问题终于解决了!!!!!!
本人刚接触机器学习没多久,前段时间在训练模型时觉得如果每次用这个模型时都需要重新去训练有些麻烦,就使用了网上大家都在用的网络模型参数保存与加载的方法,但我却在加载模型后进行测试时发现我的网络模型参数根本就没有被正确保存,搞了几天之后仍然没有找到原因,今天终于找到了原因,所以记录一下。
model.load_state_dict(torch.load(path))
torch.save(model.state_dict(), path)
以上两行代码是网上很容易就找到的模型参数保存与加载的代码
然而,敲重点!!!
上面这两行代码仅仅适用于网络参数在初始化时与模型进行了绑定的情况,也就是
w = nn.Parameter(torch.ones(10))这样使用nn.Parameter进行初始化的,如果仅仅是初始化了张量,而没有使用nn.Parameter则这个参数与模型没有绑定,也就无法通过上述方法进行保存参数。
那么如果并不想通过nn.Parameter将参数与模型进行绑定的话,应该怎么进行模型保存与加载呢?
其实也很简单,只需要将参数保存下来,然后再传回网络就可以了
torch.save(模型参数,保存路径)
params=torch.load(保存路径)
当然此时模型类在定义时需要写一个能将模型参数通过外部数据进行赋值的函数,通过该函数将params传进去即可。
仅个人理解,欢迎批评指正