When doing transfer learning or training a new complex model, partially loading a model or loading a partial model are common scenarios. Leveraging trained parameters, even if only a few are usable, will help to warmstart the training process and hopefully help your model converge much faster than training from scratch.
Introduction
Whether you are loading from a partial state_dict, which is missing some keys, or loading a state_dict with more keys than the model you are loading into, you can set the strict argument to False in the load_state_dict() function to ignore the non-matching keys. In this recipe, we will experiment with warmstarting a model using the parameters of a different model.
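To make the effect of the strict flag concrete before walking through the recipe, here is a minimal sketch; the Small and Big modules are illustrative stand-ins, not part of the recipe below:

import torch
import torch.nn as nn

class Small(nn.Module):
    def __init__(self):
        super(Small, self).__init__()
        self.fc1 = nn.Linear(4, 4)

class Big(nn.Module):
    def __init__(self):
        super(Big, self).__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)  # extra layer with no counterpart in Small

big = Big()
# With strict=True (the default) this raises a RuntimeError,
# since the "fc2.*" keys are missing from Small's state_dict:
# big.load_state_dict(Small().state_dict())
# With strict=False the matching "fc1.*" entries are loaded and the rest ignored:
big.load_state_dict(Small().state_dict(), strict=False)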
Steps
1. Import necessary libraries for loading our data
2. Define and initialize the neural networks A and B
3. Save model A
4. Load into model B
1. Import necessary libraries for loading our data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
2. Define and initialize the neural networks A and B
class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Two conv -> relu -> pool stages, then flatten into three fully connected layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()
# NetB has the same architecture as NetA, so every state_dict key will match
class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()
3. Save model A
# Specify a path to save to
PATH = "model.pt"
torch.save(netA.state_dict(), PATH)
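Before loading, it can be helpful to look at which keys the saved state_dict contains, since loading matches parameters by these names; this inspection step is an addition to the recipe:

# The state_dict maps each parameter name to its tensor; these names
# are what load_state_dict() matches against the destination model.
print(netA.state_dict().keys())
# odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias',
#             'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias',
#             'fc3.weight', 'fc3.bias'])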
4. Load into model B
netB.load_state_dict(torch.load(PATH), strict=False)
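When using strict=False, it is worth capturing the return value of load_state_dict(), a named tuple listing the keys that did not match, so a silently skipped parameter does not go unnoticed. Here NetA and NetB define identical layers, so both lists are empty:

incompatible = netB.load_state_dict(torch.load(PATH), strict=False)
print(incompatible.missing_keys)     # keys netB expects but the file lacks
print(incompatible.unexpected_keys)  # keys in the file that netB does not use

If you want to load parameters from one layer to another but some keys do not match, simply rename the parameter keys in the state_dict you are loading so they match the keys of the model you are loading into.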