Saving the model
import os
import torch

def main():
    ...
    # Inside the training loop: save a checkpoint every epoch and flag the
    # best model so far (lower validation loss = better).
    is_best = valid_loss < best_loss
    if is_best:
        best_epoch = epoch
        best_loss = valid_loss
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_loss,  # stored under the key 'best_prec'
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)  # rolling checkpoint, overwritten every epoch
    if is_best:
        torch.save(state, os.path.join(fdir, 'model_best.pth.tar'))
        logger.info("Best model is updated")
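A common variant, sketched here rather than taken from the code above, avoids serializing the same state twice by copying the file that torch.save just wrote:

import shutil

def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        # copy the already-serialized file instead of calling torch.save again
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))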
Inspecting and modifying the model
# Print every parameter/buffer name and its shape in the model.
for key in model.state_dict():
    print(key, model.state_dict()[key].size())

# Print the tensors in a loaded checkpoint (a plain state_dict here; if the
# checkpoint was saved as the dict above, iterate checkpoint['state_dict']).
for key in checkpoint:
    print(key, checkpoint[key].shape)

# Print the optimizer state: per-parameter buffers and param_groups.
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])
# Rename checkpoint keys: strip a trailing 'thresh' and append 'up'.
keys = list(checkpoint.keys())  # snapshot, since keys are mutated below
for key in keys:
    if 'thresh' in key:
        checkpoint[key[:-6] + 'up'] = checkpoint.pop(key)
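Why key[:-6]: the suffix 'thresh' is six characters long, so slicing it off and appending 'up' renames the tensor. A quick illustration with a hypothetical key name:

key = 'features.0.thresh'
print(key[:-6] + 'up')   # -> 'features.0.up'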
# The names are completely scrambled, but the parameter counts match
# one-to-one, so map the i-th checkpoint tensor onto the i-th model key:
model_keys = list(model.state_dict().keys())
for i, (key, value) in enumerate(checkpoint.items()):
    model.load_state_dict({model_keys[i]: value}, strict=False)
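Position-based loading is fragile, so a slightly safer sketch (not in the original notes) verifies that the shapes at each position actually agree, then loads everything in one call:

model_sd = model.state_dict()
model_keys = list(model_sd.keys())
remapped = {}
for i, (key, value) in enumerate(checkpoint.items()):
    assert value.shape == model_sd[model_keys[i]].shape, \
        f"shape mismatch at position {i}: {key} -> {model_keys[i]}"
    remapped[model_keys[i]] = value
model.load_state_dict(remapped, strict=False)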
Loading the model
model = modelpool(args.model, args.dataset)
start_epoch = 0
best_prec = 0
quan.replace_modules(model, args)  # project-specific: swap in custom (quantized) modules
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum, weight_decay=args.weight_decay)
if args.resume.path:
    # model, start_epoch, _ = util.load_checkpoint(
    #     model, args.resume.path, 'cuda', lean=args.resume.lean)
    checkpoint = torch.load(args.resume.path)           # load the saved checkpoint file
    model.load_state_dict(checkpoint['state_dict'])     # restore the model's state_dict
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer's state_dict
    start_epoch = checkpoint['epoch']                   # epoch to resume from
    best_prec = checkpoint['best_prec']                 # best metric seen so far
    # optimizer.load_state_dict leaves its tensors wherever they were saved,
    # so move the optimizer state back onto the GPU explicitly:
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()
print(f"start epoch: {start_epoch}, lr: {optimizer.param_groups[0]['lr']}, best_prec: {best_prec}")
model.cuda()
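If the checkpoint was written on a GPU machine but is being inspected or resumed on a CPU-only one, torch.load accepts a map_location (the cuda-moving loop above is then skipped or adapted). A one-line sketch:

checkpoint = torch.load(args.resume.path, map_location='cpu')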
adjust_lr
def adjust_learning_rate(optimizer, epoch, lr):
    # Step decay: divide the base lr by 10 every 60 epochs
    # (e.g. lr=0.1 gives 0.1 for epochs 0-59, 0.01 for 60-119, ...).
    lr *= 0.1 ** (epoch // 60)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def lr_scheduler(optimizer, epoch):
    # Milestone decay: multiply the current lr by 0.1 at epochs 100, 140, 240.
    lr_list = [100, 140, 240]
    if epoch in lr_list:
        print('change the learning rate')
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
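The same milestone schedule can be written with PyTorch's built-in MultiStepLR instead of the hand-rolled function (remember to call scheduler.step() once per epoch):

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[100, 140, 240], gamma=0.1)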
for epoch in range(start_epoch, args.epoch):
    adjust_learning_rate(optimizer, epoch, args.lr)
    lr = optimizer.param_groups[0]['lr']  # the lr actually in effect this epoch
Cosine annealing:
# Anneal the lr from its initial value down toward zero over 200 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
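Either built-in scheduler only takes effect when stepped; a minimal usage sketch, with train_one_epoch/validate standing in for the real loop body:

for epoch in range(start_epoch, args.epoch):
    train_one_epoch(model, optimizer)  # placeholder
    validate(model)                    # placeholder
    scheduler.step()                   # advance the schedule once per epoch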
SBN in validation
import torch
import torch.nn as nn

class SBN(nn.Module):
    # BatchNorm1d replacement for SNN validation: the shift terms (bias and
    # running mean) are divided by the timestep count T, while the scale
    # terms (weight and running variance) are kept as-is.
    def __init__(self, bn, T) -> None:
        super().__init__()
        self.eps = bn.eps
        self.weight = bn.weight
        self.bias = bn.bias / T
        self.val_mean = bn.running_mean / T
        self.val_var = bn.running_var

    def forward(self, x):
        # Input of shape (N, C, L): normalize per channel.
        if self.training:
            # batch statistics in training mode
            mean = x.mean([0, 2])
            var = x.var([0, 2], unbiased=False)
        else:
            # scaled running statistics in eval mode
            mean = self.val_mean
            var = self.val_var
        x = (x - mean[None, ..., None]) / torch.sqrt(var[None, ..., None] + self.eps)
        return x * self.weight[..., None] + self.bias[..., None]
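A minimal sketch (helper name is mine) for swapping every BatchNorm1d in a trained model for SBN before validation; T is the SNN timestep count these notes assume:

def replace_bn_with_sbn(module, T):
    # recurse through submodules and swap BatchNorm1d layers in place
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm1d):
            setattr(module, name, SBN(child, T))
        else:
            replace_bn_with_sbn(child, T)

replace_bn_with_sbn(model, T=4)  # e.g. 4 timesteps; then model.eval()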
# Dump parameter names/shapes of the model and the checkpoint for comparison.
with open("model_params.txt", "w") as fo:
    fo.write(f"{len(model.state_dict())}\n")
    for key in model.state_dict():
        fo.write(f"{key}\t{model.state_dict()[key].size()}\n")

with open("checkpoint_params.txt", "w") as f:
    f.write(f"{len(checkpoint)}\n")
    for key in checkpoint:
        f.write(f"{key}\t{checkpoint[key].shape}\n")
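With both dumps written, a plain diff model_params.txt checkpoint_params.txt shows at a glance which names or shapes disagree.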