![9f1729e567a8ef0a4e4274ecffda53b4.png](https://img-blog.csdnimg.cn/img_convert/9f1729e567a8ef0a4e4274ecffda53b4.png)
在用pytorch写代码的时候,经常遇到要临时储存模型参数,后续再加载该模型的情况。例如预训练模型之后,肯定要将预训练的模型存储一下,以便后续训练过程再载入该模型。这边常见的做法是利用torch.save load load_state_dict这三个函数进行处理。
torch.save({'model': model.state_dict()}, 'zhenghuo.model')
# 将model的参数储存到zhenghuo.model里面
config = torch.load('zhenghuo.model')
newmodel.load_state_dict(config['model'], strict=False)
# 假设我们新建了一个newmodel实例,将之前储存的参数加载进来
但是这里我产生了一个巨大疑问。假设我的模型里面有两个全连接层,那加载的时候怎么区分这两个线性层是不一样的呢。
这边我设置一个有两个全连接层的Model,直接读取其参数列表,看看会输出什么。
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(3, 3)
a = Model()
print(a.state_dict())
![28f9f9d6fc7df2b5aff8d6cf2529f09d.png](https://img-blog.csdnimg.cn/img_convert/28f9f9d6fc7df2b5aff8d6cf2529f09d.png)
可以看到state_dict把我设置的全连接层的名字也存下来了,linear1和linear2分别储存各自的权重和偏移量,这样就避免了歧义的问题。
那研究中还可能遇到一种情况(真实经历):我辛辛苦苦用机子跑了一个礼拜,得到了一个很满意的模型。结果我突然发现我要改一下model结构。例如加一个全连接层或者删除一个全连接层。这样模型结构就发生了变化。不一样的model是否还可以加载参数呢。
这里举一个例子。我们假设一开始用的是newModel,里面只有一个全连接层linear1。后来我突发奇想要再加一个全连接层linear2,也就是Model。这样我拿newModel的参数加载给Model,会怎么样呢?
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(3, 3)
class newModel(nn.Module):
def __init__(self):
super(newModel, self).__init__()
self.linear1 = nn.Linear(2, 2)
![5cc513df66c2e86acf3352cb74151b79.png](https://img-blog.csdnimg.cn/img_convert/5cc513df66c2e86acf3352cb74151b79.png)
可以看到newModel存下来的参数只有linear1。我把它存在了zhenghuo.model里面(整活)
a = Model()
print(a.state_dict())
c = torch.load('zhenghuo.model')
a.load_state_dict(c['model'], strict=False) # strict划重点
print(a.state_dict())
a是一个Model的实例,我们分别读取它在加载参数之前和之后的state_dict()
![5b74014866acbced47e489f292379c0e.png](https://img-blog.csdnimg.cn/img_convert/5b74014866acbced47e489f292379c0e.png)
结果是只有linear1的权重和偏移量被覆盖了,linear2保持随机初始化出来的数据(我的zhenghuo.model里面也根本没有关于linear2的任何信息)。其中load_state_dict有一个参数strict,将其设置成true则会报错,报错信息是没有linear2的参数。也就是说当strict=True的时候,只有之前保存的模型和新模型一模一样,才可以加载参数。
我再把Model的两个全连接层名字更改一下
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear2 = nn.Linear(2, 2)
self.linear3 = nn.Linear(3, 3)
class newModel(nn.Module):
def __init__(self):
super(newModel, self).__init__()
self.linear1 = nn.Linear(2, 2)
可以看到现在Model里面已经不包含linear1这个全连接层了,按我的理解,即使再读取newModel的参数,也不会对Model的参数产生影响。
![1770ffc8e98effd237c85de53647663e.png](https://img-blog.csdnimg.cn/img_convert/1770ffc8e98effd237c85de53647663e.png)
可以看到读取参数前后,模型参数确实没变。
也就是说,假设我预训练好了一个模型,储存了该模型的参数。后续当我想要对模型进行拓展,直接加载之前保存的参数也是没有问题的。需要注意的是原模型的部件不能再被修改,只能增加或者删除。如果更改了某个部件例如一个全连接层的名字,那么在加载过程中该部件将不会继承预训练模型的数据。如果真的不想要这个全连接层删除掉就好了。