1. 固定随机种子,并设置 cudnn.deterministic = True、cudnn.benchmark = False
# --- Reproducibility setup ---
import os
import torch
import torchvision
import math
import random
import numpy as np

# Seed every RNG the script may touch.  The original only seeded PyTorch,
# leaving `random` and NumPy unseeded even though both are imported and the
# section heading promises fixed random seeds.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Force cuDNN to use deterministic kernels and disable the benchmark
# autotuner, which would otherwise select (possibly nondeterministic)
# algorithms per input shape — trades some GPU speed for repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
2. 初始化模型,设置优化器和学习率调整算法
# Build a ResNet-18 backbone and swap its final fully-connected layer for a
# 2-way classification head.
model = torchvision.models.resnet18()
model.fc = torch.nn.Linear(512, 2)

# SGD with Nesterov momentum; the LambdaLR schedule below scales the base lr.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.95, nesterov=True)

steps = 10
lrf = 0.01


def _cosine_factor(step):
    """Cosine-annealing multiplier: 1.0 at step 0, decaying to `lrf` at `steps`."""
    return ((1 - math.cos(step * math.pi / steps)) / 2) * (lrf - 1) + 1


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_cosine_factor)
3. 模型训练过程
# --- Training loop ---
start_epoch = 0
# Dummy batch: 2 RGB images of 32x32 (N, C, H, W).
inputs = torch.rand((2, 3, 32, 32))
for epoch in range(start_epoch + 1, 20):
    output = model(inputs)
    loss = torch.nn.functional.cross_entropy(output, torch.tensor([0, 1], dtype=torch.long))
    print(loss)
    # BUG FIX: the original called optimizer.zero_grad() AFTER loss.backward()
    # and before optimizer.step(), wiping the freshly computed gradients so
    # step() never updated the weights.  Correct order: clear stale gradients,
    # backpropagate, then apply the update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
4. 保存模型权重、优化器参数、以及参数更新策略
# Bundle everything needed to resume training into a single checkpoint dict.
checkpoint = dict(
    model=model.state_dict(),
    optimizer=optimizer.state_dict(),
    epoch=epoch,
    lr_schedule=scheduler.state_dict(),
)
# Ensure the target directory exists (no-op if already present), then persist.
os.makedirs("./model_parameter", exist_ok=True)
torch.save(checkpoint, './model_parameter/check.pth')
5. 加载模型权重、优化器参数、以及参数更新策略
# --- Restore a previous run from disk ---
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
ckpt_path = "./model_parameter/check.pth"
state = torch.load(ckpt_path)
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['lr_schedule'])
start_epoch = state['epoch']
6. 整体过程如下:
# --- Reproducibility setup ---
import os
import torch
import torchvision
import math
import random
import numpy as np

# Seed every RNG the script may touch.  The original only seeded PyTorch,
# leaving `random` and NumPy unseeded even though both are imported.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Deterministic cuDNN kernels; disable the benchmark autotuner so algorithm
# selection cannot vary between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ResNet-18 backbone with a 2-way linear classification head.
model = torchvision.models.resnet18()
model.fc = torch.nn.Linear(512, 2)

# SGD with Nesterov momentum; lr is modulated by the cosine schedule below.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.95, nesterov=True)

steps = 100
lrf = 0.01


def _cosine_lr(step):
    """Cosine-annealing multiplier: 1.0 at step 0, decaying to `lrf` at `steps`."""
    return ((1 - math.cos(step * math.pi / steps)) / 2) * (lrf - 1) + 1


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_cosine_lr)
start_epoch = 0
resume = True
if resume:
    # Restore model/optimizer/scheduler state plus the last finished epoch
    # so training continues exactly where it stopped.
    # NOTE(review): torch.load unpickles arbitrary objects — trusted files only.
    ckpt = torch.load("./model_parameter/check.pth")
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['lr_schedule'])
    start_epoch = ckpt['epoch']
# Dummy batch: 2 RGB images of 32x32 (N, C, H, W).
inputs = torch.rand((2, 3, 32, 32))
# Continue for 9 more epochs after the restored start_epoch.
for epoch in range(start_epoch + 1, start_epoch + 10):
    output = model(inputs)
    loss = torch.nn.functional.cross_entropy(output, torch.tensor([0, 1], dtype=torch.long))
    print(loss)
    # BUG FIX: the original zeroed gradients between backward() and step(),
    # so step() always saw all-zero grads and the weights never changed.
    # Correct order: zero_grad -> backward -> step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
# Persist everything needed for a later resume in one checkpoint dict.
checkpoint = dict(
    model=model.state_dict(),
    optimizer=optimizer.state_dict(),
    epoch=epoch,
    lr_schedule=scheduler.state_dict(),
)
# Create the output directory if it does not exist yet, then save.
os.makedirs("./model_parameter", exist_ok=True)
torch.save(checkpoint, './model_parameter/check.pth')