# 5. 加载预训练模型与冻结解冻模型参数示例
# (Example: loading a pretrained model and freezing/unfreezing its parameters)
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
# MNIST data pipeline.
# Fix: Normalize expects per-channel mean/std sequences; the original passed
# (0.1307) — a plain float, not a tuple — and (0.3081, 1), which declares two
# channels for single-channel MNIST and would raise at runtime.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_data = datasets.MNIST(root="..\\mnist\\", train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
val_data = datasets.MNIST(root="..\\mnist\\", train=False, transform=transform, download=True)
# Fix: named val_loader (was val_dataloader) because val() iterates `val_loader`.
val_loader = DataLoader(dataset=val_data, batch_size=64, shuffle=False)


class MyLeNet(nn.Module):
    """LeNet-5-style CNN for 1x28x28 MNIST digits (10 classes)."""

    def __init__(self):
        super(MyLeNet, self).__init__()
        # Feature extractor: 1x28x28 -> 16x24x24 -> 16x12x12 -> 32x8x8 -> 32x4x4.
        # NOTE: layer layout kept exactly as-is so pretrained state_dict keys
        # (feature.0, feature.1, ...) still match.
        self.feature = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        x = self.feature(x)
        x = x.view(-1, 32 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
def train(epoch):
    """Run one training epoch, print the mean batch loss, and checkpoint.

    Saves model/optimizer/epoch state to model_<epoch>.pth so training can
    later be resumed via the ``resume`` path in __main__.
    """
    loss_runtime = 0.0
    # Fix: the original wrote enumerate(tqdm(train_loader, 0)), passing 0 to
    # tqdm (its desc argument) instead of to enumerate as the start index.
    for batch, data in enumerate(tqdm(train_loader), 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_runtime += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Fix: average once over the number of batches; the original divided the
    # running sum by the batch size on every iteration, mangling the value.
    loss_runtime /= len(train_loader)
    # Fix: "%d" — the original "% " was a broken conversion for the epoch.
    print("after %d epochs, loss is %.8f" % (epoch + 1, loss_runtime))
    # state_dict is a Python dict mapping each layer with learnable parameters
    # to its tensors; saving the optimizer state allows exact resumption.
    save_file = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(save_file, "model_{}.pth".format(epoch))


def val():
    """Evaluate on the validation set and print top-1 accuracy."""
    correct, total = 0, 0
    with torch.no_grad():
        for (x, y) in val_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            # Class prediction = index of the max logit along the class dim.
            _, pred = torch.max(y_pred.data, dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = correct / total
    # Fix: "%.5f" — "%5f" set a field width of 5, not 5 decimal places.
    print("accuracy on val set is : %.5f" % acc)


if __name__ == "__main__":
    pass  # script entry point — body continues on the lines below
start_epoch =0
freeze_epoch =0
resume ="..\\002模型冻结解冻\\lenet5_pretrained_weight.pt"#lenet5预训练权重链接为https://github.com/SteveJRZ
freeze =True
model = MyLeNet()
device =("cuda:0"if torch.cuda.is_available()else"cpu")
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)#加载预训练权重if resume isTrue:
checkpoint = torch.load(resume, map_loaction="cpu")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']#冻结训练if freeze:
freeze_epoch =5print("冻结前置特征提取网络权重,训练后面的全连接层")for param in model.feature.parameters():
param.requires_grad =False#将不更新的参数requires_grad设置为False,
optimizer = torch.optim.SGD(filter(lambda p: p.requires_gard, model.parameters()), lr=0.01, momentum=0.5)for epoch inrange(start_epoch, start_epoch + freeze_epoch):
train(epoch)
val()print("解冻前置特征提取网络权重,接着训练整个网络权重")for param in mode.feature.parameters():
param.requires_grad =True
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr =0.01, momentum=0.5)
# 6. 补充知识点 (Supplementary notes)
#(1)torchvision#pytorch中torchvision里已有很多常用模型,可直接调用如alexnet、vgg、densenet等import torchvision.models as models
alexnet = models.alexnet()
resnet18 = models.resnet18()#(2)加载预训练模型#这是我们自己网络模型参数的有序字典形式(网络参数名:值)
net_dict = net.state_dict()#这是实际加载的预训练好的网络模型参数的有序字典形式
pretrained_dict = torch.load(pretrained_path)#从预训练的参数中加载我们的网络中需要的模型参数(这个很重要、有时需要冻结某一层的参数、可用这条语句从预训练的整个网络参数中筛选出我们需要的某一层的参数)
pretrained_dict ={k: v for k, v in pretrained_dict.items()if k in net_dict}#字典的updata方法,进行字典的更新(个人感觉不是必要的)
net_dict.update(pretrained_dict)#按照键与键的对应关系、加载网络参数
net.load_state_dict(net_dict)