AI算法与应用-01深度学习框架-002模型冻结与解冻

#pytorch冻结网络参数

1. 动机与意义
#(1)避免过拟合:当训练数据较少时,神经网络容易过拟合,冻结一些参数可减少网络的自由度,避免过拟合
#(2)加速训练:冻结某些参数时,冻结的参数在训练时不需更新,可减少训练时间
#(3)迁移学习:通常使用预训练好的模型来初始化目标任务的模型.冻结预训练模型的一部分参数可以保持这些参数
#的特增提取能力不变,只训练目标任务的部分参数
#(4)稳定网络:在生成对抗网络中,冻结判别器的参数可以保证生成器更容易生成真实样本

2. 哪些层数会参与梯度更新
#(1)一般模型参数的冻结都是针对全连接层和卷积层而言,因为这些层通常具有大量的可学习参数,而且在训练过程中
#容易过拟合,通过冻结这些层的参数,可减少模型中需要学习的参数数量,从而降低过拟合风险,提高模型泛化性.
#池化层、归一化层等通常具有较少可学习参数,因此他们的参数很少被冻结.但有时也会对其冻结,例如,使用预训练
#模型微调时,通常会冻结训练模型中的所有层,并只对新添加的全连接层训练,以充分利用预训练模型的提取能力

3. 冻结模型参数示例
import torch
import torch.nn as nn

#定义一个模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

#冻结第一层(两种方法)
model[0].requires_grad = False
model[0].requires_grad_(False)

#冻结第一个线性层
for param in model[0].parameters():
    print(param)
    param.requires_grad = False

#冻结前两层
for param in model[:2].parameters():
    param.requires_grad = False

#验证参数状态
for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")
#验证哪些参数被冻结,哪些未被冻结
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}的参数正在被训练")
    else:
        print(f"{name}的参数已被冻结")

4. 模型冻结实现思路

#在加载预训练权重后,可能需要固定一部分模型的参数,只更新另一部分参数。三种实现思#路
#(1)一个是设置不要更新参数的网络层为reuires_grad=False,
#(2)另一个为在定义优化器时只传入要更新的参数。
#(3)最优方式:将不更新的参数requires_grad设置为False,同时传入该参数传入#optimizer

#不冻结层时
class net(nn.Module):
    def __init__(self, num_class=10):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    
    def forward(self, x):
        return self.fc2(self.fc1(x))

model = net()

loss_fun = nn.CrossEntropyLoss()
print(list(model.parameters())[-3:]) #输出模型最后三层的参数
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0, 10, [3]).long()
    output = model(x)

    loss = loss_fun(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
#冻结fc1层
#(1)将优化器传入所有的参数
#(2)将要冻结层的参数requires_grad设置为False
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0, 10, [3]).long()
    output = model(x)

    loss = loss_fun(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
#方式二冻结fc1层
#优化器传入不冻结的fc2层的参数
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc2.parameters(), lr=1e-2)
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0, 3, [3]).long()
    output = model(x)

    loss = loss_fun(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
#将不更新的参数requires_grad设置为False,同时不将该参数传入optimizer
#将不更新的参数传入optimizer,节省显存
#将不更新的参数requires_grad设置为False,节省计算部分参数梯度的时间
loss_fun =  nn.CrossEntropyLoss()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print("model.fc1.weight.requires_grad", model.fc1.weight.requires_grad)
print("model.fc2.weight.requires_grad", model.fc2.weight.requires_grad)

for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False

optimizer = torch.optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0, 3, [3]).long()
    output = model(x)

    loss = loss_fun(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print("model.fc1.weight.requires_grad", model.fc1.weight.requires_grad)
print("model.fc2.weight.requires_grad", model.fc2.weight.requires_grad)

5. 加载预训练模型与冻结解冻模型参数示例

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307), (0.3081, 1))])
train_data = datasets.MNIST(root = "..\\mnist\\", train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
val_data = datasets.MNIST(root = "..\\mnist\\", train=False, transform=transform, download=True)
val_dataloader = DataLoader(dataset=val_data, batch_size=64, shuffle=False)

class MyLeNet(nn.Module):
    def __init__(self):
        super(MyLeNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2)
        )
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = self.feature(x)
        x = x.view(-1, 32*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train(epoch):
    loss_runtime = 0.0
    for batch, data in enumerate(tqdm(train_loader, 0)):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_runtime += loss.item()
        loss_runtime /= x.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print("after % epochs, loss is %.8f" % (epoch + 1, loss_runtime))
    #state_dict是python字典对象,它将每一层映射到其参数张量,只有具有可学习参数的层
    #才具有state_dict
    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "epoch": epoch
    }

    torch.save(save_file, "model_{}.pth".format(epoch))

def val():
    correct, total = 0, 0
    with torch.no_grad():
        for (x , y) in val_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            _, pred = torch.max(y_pred.data, dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
            acc = correct / total
    print("accuracy on val set is : %5f" % acc)

if __name__ == "__main__":
    start_epoch = 0
    freeze_epoch = 0
    resume = "..\\002模型冻结解冻\\lenet5_pretrained_weight.pt"
    #lenet5预训练权重链接为https://github.com/SteveJRZ
    freeze = True

    model = MyLeNet()
    device = ("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    #加载预训练权重
    if resume is True:
        checkpoint = torch.load(resume, map_loaction="cpu")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

        #冻结训练
        if freeze:
            freeze_epoch = 5
            print("冻结前置特征提取网络权重,训练后面的全连接层")
            for param in model.feature.parameters():
                param.requires_grad = False #将不更新的参数requires_grad设置为False,
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_gard, model.parameters()), lr=0.01, momentum=0.5)
            for epoch in range(start_epoch, start_epoch + freeze_epoch):
                train(epoch)
                val()
            print("解冻前置特征提取网络权重,接着训练整个网络权重")
            for param in mode.feature.parameters():
                param.requires_grad = True
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr = 0.01, momentum=0.5)
6. 补充知识点
#(1)torchvision
#pytorch中torchvision里已有很多常用模型,可直接调用如alexnet、vgg、densenet等
import torchvision.models as models
alexnet = models.alexnet()
resnet18 = models.resnet18()
#(2)加载预训练模型
#这是我们自己网络模型参数的有序字典形式(网络参数名:值)
net_dict = net.state_dict()
#这是实际加载的预训练好的网络模型参数的有序字典形式    
pretrained_dict = torch.load(pretrained_path)
#从预训练的参数中加载我们的网络中需要的模型参数(这个很重要、有时需要冻结某一层的参数、可用这条语句从预训练的整个网络参数中筛选出我们需要的某一层的参数)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
#字典的updata方法,进行字典的更新(个人感觉不是必要的)
net_dict.update(pretrained_dict)
#按照键与键的对应关系、加载网络参数    
net.load_state_dict(net_dict)
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值