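In practice, we usually initialize our own model with the weights of a model that has already been trained. This is called finetuning and, more broadly, transfer learning. In essence, the finetune technique in transfer learning gives a newly built model a good set of initial weights.
The three steps of finetune weight initialization
Finetuning amounts to initializing the model, and the process has three steps:
Step 1: save the model, so that a pretrained model is available;
Step 2: load the model and take the weights out of the pretrained model;
Step 3: initialize, that is, "put" each weight into the corresponding place in the new model.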
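The three steps above can be sketched end to end on a toy model. This is a minimal sketch, not the article's training code: `TinyNet`, the temporary directory, and the file name `tiny_params.pkl` are assumptions made up for illustration, and the "pretrained" model is simply an untrained stand-in.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Step 1: save the weights of the (stand-in) pretrained model.
pretrained = TinyNet()
ckpt_path = os.path.join(tempfile.mkdtemp(), "tiny_params.pkl")  # hypothetical path
torch.save(pretrained.state_dict(), ckpt_path)

# Step 2: load the saved weights back as a dictionary.
pretrained_dict = torch.load(ckpt_path)

# Step 3: put the weights into a freshly built model.
new_model = TinyNet()
new_model.load_state_dict(pretrained_dict)

# Check that every tensor in the new model now equals the saved one.
same = all(
    torch.equal(p, q)
    for p, q in zip(pretrained.state_dict().values(),
                    new_model.state_dict().values())
)
print(same)
```

Saving only the `state_dict` rather than the whole model object keeps the checkpoint independent of the class definition, which is what makes step 3 possible on a differently built model.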
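1. Code
1) Loading the data
import torchvision.transforms as transforms
# Preprocessing settings: per-channel mean and standard deviation
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
    # Resize so that the shorter side is 32 pixels (32x32 for square images)
    transforms.Resize(32),
    # Pad 4 pixels on every side, then randomly crop a 32x32 patch
    transforms.RandomCrop(32, padding=4),
    # Reorder the data from H x W x C to C x H x W,
    # then divide every pixel by 255 so that values lie in [0, 1]
    transforms.ToTensor(),
    # Standardize each channel with the mean and std defined above
    normTransform
])
validTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])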
import os
import sys
sys.path.append("./../util/")  # relative path to the utility module
from torch.utils.data import DataLoader
from utils import MyDataset
base_dir = "E:/pytorch_learning"  # change to the absolute path of the directory that contains Data
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
# Build the MyDataset instances
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)
# Batch sizes
train_bs = 16
valid_bs = 16
# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)
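To make the `ToTensor` and `Normalize` comments concrete, the arithmetic for a single pixel can be worked through in plain Python. The pixel value `(128, 64, 255)` and the helper `normalize_pixel` are assumptions for illustration; the mean and std values are the ones used above.

```python
norm_mean = [0.4948052, 0.48568845, 0.44682974]
norm_std = [0.24580306, 0.24236229, 0.2603115]


def normalize_pixel(rgb):
    """Mimic ToTensor scaling followed by Normalize for one RGB pixel."""
    # ToTensor: scale 8-bit values from [0, 255] to [0, 1].
    scaled = [v / 255.0 for v in rgb]
    # Normalize: subtract the channel mean and divide by the channel std.
    return [(s - m) / d for s, m, d in zip(scaled, norm_mean, norm_std)]


pixel = (128, 64, 255)  # hypothetical pixel value
result = normalize_pixel(pixel)
print(result)

# A pixel that equals the dataset mean maps to (approximately) zero.
mean_pixel = [m * 255.0 for m in norm_mean]
print(normalize_pixel(mean_pixel))
```

Channels brighter than the dataset mean come out positive and darker ones negative, so after this step the inputs are roughly zero-centered with unit spread per channel.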
2) Defining the network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # in 3, out 6, kernel 5: 3x32x32 -> 6x28x28
        self.pool1 = nn.MaxPool2d(2, 2)  # kernel 2, stride 2: 6x28x28 -> 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # in 6, out 16, kernel 5: 6x14x14 -> 16x10x10
        self.pool2 = nn.MaxPool2d(2, 2)  # 16x10x10 -> 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 * 5 * 5 -> 120
        self.fc2 = nn.Linear(120, 84)  # 120 -> 84
        self.fc3 = nn.Linear(84, 10)  # 84 -> 10
    # conv1->relu->pool1->conv2->relu->pool2->fc1->relu->fc2->relu->fc3
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    # Define the weight initialization
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)  # Xavier normal initialization
                if m.bias is not None:
                    m.bias.data.zero_()  # set all biases to 0
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)  # set all weights to 1
                m.bias.data.zero_()  # set all biases to 0
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)  # normal initialization, mean 0, std 0.01
                m.bias.data.zero_()  # set all biases to 0
net = Net()  # create a network instance
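The shape comments in the network definition can be checked with the standard output-size formula for unpadded convolution and pooling layers, `out = (in + 2 * padding - kernel) // stride + 1`. The helper `out_size` below is a hypothetical name introduced only for this check.

```python
def out_size(in_size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer along one dimension."""
    return (in_size + 2 * padding - kernel) // stride + 1


size = 32
size = out_size(size, kernel=5)            # conv1: 32 -> 28
size = out_size(size, kernel=2, stride=2)  # pool1: 28 -> 14
size = out_size(size, kernel=5)            # conv2: 14 -> 10
size = out_size(size, kernel=2, stride=2)  # pool2: 10 -> 5

# conv2 has 16 output channels, so the flattened feature count is 16 * 5 * 5.
flat_features = 16 * size * size
print(size, flat_features)  # 5 400
```

The result, 400, is the input width of `fc1` and the second argument of `x.view(-1, 16 * 5 * 5)` in `forward`; changing the input resolution would require changing both.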
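3) Finetune weight initialization
# Load the saved parameters (.pkl file)
pretrained_dict = torch.load(r'E:\pytorch_learning\Data\net_params.pkl')
# Get the state dict of the current network
net_state_dict = net.state_dict()
# Drop pretrained entries whose keys do not exist in the current network
pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
# Update the new model's parameter dictionary
net_state_dict.update(pretrained_dict_1)
# "Put" the dictionary that now contains the pretrained parameters into the new model
net.load_state_dict(net_state_dict)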
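The filter-then-update pattern is ordinary dictionary logic, so its effect can be seen without any tensors. In this sketch plain strings stand in for weight tensors, and the key `old_head.weight` is a made-up example of a layer that exists only in the pretrained model.

```python
# Stand-in for the loaded checkpoint (strings instead of tensors).
pretrained_dict = {
    "conv1.weight": "pretrained",
    "conv1.bias": "pretrained",
    "fc3.weight": "pretrained",
    "old_head.weight": "pretrained",  # hypothetical layer absent from the new model
}

# Stand-in for net.state_dict() of the freshly built model.
net_state_dict = {
    "conv1.weight": "random init",
    "conv1.bias": "random init",
    "fc3.weight": "random init",
    "fc3.bias": "random init",  # absent from the checkpoint
}

# Keep only the pretrained entries whose keys exist in the new model.
matched = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}

# Overwrite the matching entries; everything else keeps its initial value.
net_state_dict.update(matched)

skipped = sorted(set(pretrained_dict) - set(matched))
kept_init = sorted(k for k, v in net_state_dict.items() if v == "random init")
print(skipped)    # ['old_head.weight']
print(kept_init)  # ['fc3.bias']
```

Note that this filter compares names only. If a key matches but the tensor shapes differ (for example, a final layer with a different number of classes), `load_state_dict` raises an error, so such keys would also need to be filtered out, for instance by comparing shapes as well.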
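2. Results
Printing the dictionary shows that each key is the name of a layer's parameter (for example, conv1.weight) and each value is the corresponding weight tensor.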
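A quick way to see this structure is to print the key and tensor shape of every entry. The sketch below uses a hypothetical `SmallNet` with just two of the layers from the network above, so that it runs on its own.

```python
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical cut-down model with two of the layers from Net."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.fc3 = nn.Linear(84, 10)


small_net = SmallNet()

# Map each parameter name to the shape of its tensor.
shapes = {name: tuple(t.shape) for name, t in small_net.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# conv1.weight (6, 3, 5, 5)
# conv1.bias (6,)
# fc3.weight (10, 84)
# fc3.bias (10,)
```

The keys are built from the attribute name of the layer plus the parameter name, which is why a pretrained dictionary can be matched against a new model by key.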
The remaining parameters are covered in the following chapters.