The training procedure of the network is as follows:
1. Load the data, including the training set and the validation set
1) Select the device
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True  # let cuDNN benchmark and pick the fastest kernels for this workload
2) Define the normalization applied to the data and labels
# Normalize data and label to [-1, 1]
transform_data = Compose([
    T.LogTransform(k=args.k),
    T.MinMaxNormalize(T.log_transform(ctx['data_min'], k=args.k),
                      T.log_transform(ctx['data_max'], k=args.k))
])
transform_label = Compose([
    T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
])
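Both transforms come from the project's transforms module (imported as T). As a rough sketch of what they do, under the assumption of typical definitions and not the project's exact code: LogTransform applies a signed logarithm to compress the large dynamic range of seismic amplitudes, and MinMaxNormalize then rescales linearly to [-1, 1]. This also explains why the data min/max are passed through log_transform first: the data is normalized in log space, while the labels are normalized on their raw range.
import numpy as np

# Illustrative sketch only; the repo's transforms may differ in detail
def log_transform(data, k=1):
    # signed log: keeps the sign, compresses the amplitude range
    return np.log1p(np.abs(k * data)) * np.sign(data)

class MinMaxNormalize:
    def __init__(self, vmin, vmax):
        self.vmin, self.vmax = vmin, vmax

    def __call__(self, x):
        # map [vmin, vmax] linearly onto [-1, 1]
        return (x - self.vmin) / (self.vmax - self.vmin) * 2 - 1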
3) Initialize the training and validation datasets
if args.train_anno[-3:] == 'txt':
    dataset_train = FWIDataset(
        args.train_anno,
        preload=True,
        sample_ratio=args.sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label
    )
else:
    dataset_train = torch.load(args.train_anno)

print('Loading validation data')
if args.val_anno[-3:] == 'txt':
    dataset_valid = FWIDataset(
        args.val_anno,
        preload=True,
        sample_ratio=args.sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label
    )
else:
    dataset_valid = torch.load(args.val_anno)
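FWIDataset is defined in the project's dataset module. Below is a simplified sketch, under the assumption that each line of the annotation .txt file lists a seismic-data .npy path and the matching velocity-model .npy path, with file_size samples per file; preload and sample_ratio are omitted for brevity, so this is orientation only, not the real class:
import numpy as np
from torch.utils.data import Dataset

# Hypothetical skeleton of FWIDataset, heavily simplified
class FWIDatasetSketch(Dataset):
    def __init__(self, anno, file_size, transform_data=None, transform_label=None):
        with open(anno) as f:
            self.batches = [line.split() for line in f]  # each line: data_path label_path
        self.file_size = file_size
        self.transform_data = transform_data
        self.transform_label = transform_label

    def __len__(self):
        return len(self.batches) * self.file_size

    def __getitem__(self, idx):
        batch_idx, sample_idx = idx // self.file_size, idx % self.file_size
        data_path, label_path = self.batches[batch_idx]
        data = np.load(data_path)[sample_idx].astype('float32')
        label = np.load(label_path)[sample_idx].astype('float32')
        if self.transform_data:
            data = self.transform_data(data)
        if self.transform_label:
            label = self.transform_label(label)
        return data, label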
4) Build the data loaders
train_sampler = RandomSampler(dataset_train)
valid_sampler = RandomSampler(dataset_valid)
dataloader_train = DataLoader(
    dataset_train, batch_size=args.batch_size,
    sampler=train_sampler, num_workers=args.workers,
    pin_memory=True, drop_last=True,
    collate_fn=default_collate)  # default_collate stacks a list of samples into batched tensors
dataloader_valid = DataLoader(
    dataset_valid, batch_size=args.batch_size,
    sampler=valid_sampler, num_workers=args.workers,
    pin_memory=True, collate_fn=default_collate)
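As a quick, purely illustrative sanity check, one batch drawn from dataloader_train should come back as stacked tensors whose leading dimension is the batch size:
# Illustrative check: fetch a single batch from the training loader
data, label = next(iter(dataloader_train))
print(data.shape, label.shape)  # leading dimension of both equals args.batch_size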
2. Define the loss function, optimizer, and learning-rate schedule
1) Loss function
l1loss = nn.L1Loss()
l2loss = nn.MSELoss()

def criterion(pred, gt):
    loss_g1v = l1loss(pred, gt)
    loss_g2v = l2loss(pred, gt)
    loss = args.lambda_g1v * loss_g1v + args.lambda_g2v * loss_g2v
    return loss, loss_g1v, loss_g2v
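For instance, with lambda_g1v = lambda_g2v = 1 the criterion simply adds the two losses. A toy check (shapes and values here are arbitrary):
# Toy usage of criterion; not part of the training script
pred = torch.zeros(2, 1, 70, 70)
gt = torch.ones(2, 1, 70, 70)
loss, loss_g1v, loss_g2v = criterion(pred, gt)  # L1 = 1.0 and L2 = 1.0 on this toy input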
2) Optimizer
The optimizer has to exist before the scheduler, since WarmupMultiStepLR wraps it; the base learning rate is also scaled to account for the effective batch size across processes.
# Scale lr according to effective batch size
lr = args.lr * args.world_size
optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                              betas=(0.9, 0.999), weight_decay=args.weight_decay)
3) Learning-rate schedule
# Convert scheduler to be per iteration instead of per epoch
warmup_iters = args.lr_warmup_epochs * len(dataloader_train)
lr_milestones = [len(dataloader_train) * m for m in args.lr_milestones]
lr_scheduler = WarmupMultiStepLR(
    optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
    warmup_iters=warmup_iters, warmup_factor=1e-5)
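WarmupMultiStepLR is not a built-in PyTorch scheduler; it comes from the project's scheduler module. A minimal sketch, assuming the common linear-warmup-plus-step-decay formulation used in the torchvision reference scripts (the project's class may differ in details):
from bisect import bisect_right
import torch

# Sketch of a warmup + multi-step scheduler, illustrative only
class WarmupMultiStepLRSketch(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1,
                 warmup_iters=500, warmup_factor=1e-5, last_epoch=-1):
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # linear warmup from warmup_factor * base_lr up to base_lr ...
        factor = 1.0
        if self.last_epoch < self.warmup_iters:
            alpha = self.last_epoch / self.warmup_iters
            factor = self.warmup_factor * (1 - alpha) + alpha
        # ... then decay by gamma at each milestone (counted in iterations)
        return [base_lr * factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]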
3. Build the model, optionally loading a pretrained checkpoint
1) Build the model
model = network.model_dict[args.model](upsample_mode=args.up_mode,
                                       sample_spatial=args.sample_spatial,
                                       sample_temporal=args.sample_temporal).to(device)
model_without_ddp = model  # single-process run; no DistributedDataParallel wrapper
2) Resume from a checkpoint
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(network.replace_legacy(checkpoint['model']))
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    args.start_epoch = checkpoint['epoch'] + 1
    step = checkpoint['step']
    lr_scheduler.milestones = lr_milestones
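network.replace_legacy evidently rewrites state_dict keys saved by an older version of the code so they match the current module names. A hypothetical sketch; the rename rule shown is illustrative only, the real rules live in network.py:
from collections import OrderedDict

# Hypothetical sketch of replace_legacy
def replace_legacy(old_state_dict):
    rules = [('module.', '')]  # illustrative rule: strip a DataParallel prefix
    new_state_dict = OrderedDict()
    for key, value in old_state_dict.items():
        for old, new in rules:
            key = key.replace(old, new)
        new_state_dict[key] = value
    return new_state_dict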
4. Start training: train the model, save the checkpoints and losses, and record the validation metrics
1) Train the model
train_one_epoch(model, criterion, optimizer, lr_scheduler, dataloader_train,
                device, epoch, args.print_freq, train_writer)
train_one_epoch computes the loss, backpropagates, and updates the model, as in the core steps below; it also records the L1, L2, and combined losses in the TensorBoard train_writer (see the fuller sketch after the snippet).
optimizer.zero_grad()                                # clear gradients from the previous iteration
data, label = data.to(device), label.to(device)      # move the batch to the target device
output = model(data)                                 # forward pass
loss, loss_g1v, loss_g2v = criterion(output, label)  # weighted L1 + L2 loss
loss.backward()                                      # backpropagate
optimizer.step()                                     # update the parameters
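Putting it together, a hedged sketch of what train_one_epoch plausibly looks like; the logging tags, the global step counter, and the handling of print_freq (console logging interval, omitted here) are assumptions:
# Sketch of train_one_epoch; details are illustrative
def train_one_epoch(model, criterion, optimizer, lr_scheduler, dataloader,
                    device, epoch, print_freq, writer):
    global step  # assumed module-level iteration counter
    model.train()
    for data, label in dataloader:
        optimizer.zero_grad()
        data, label = data.to(device), label.to(device)
        output = model(data)
        loss, loss_g1v, loss_g2v = criterion(output, label)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # the scheduler advances per iteration, not per epoch
        if writer:
            writer.add_scalar('loss', loss.item(), step)
            writer.add_scalar('loss_g1v', loss_g1v.item(), step)
            writer.add_scalar('loss_g2v', loss_g2v.item(), step)
        step += 1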
2) Evaluate the model
loss = evaluate(model, criterion, dataloader_valid, device, val_writer)
evaluate computes the model's combined loss over the validation set and likewise records the L1, L2, and combined losses in the TensorBoard val_writer; a hedged sketch follows.
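A minimal sketch of evaluate, assuming simple batch averaging and only the combined loss logged (the per-component losses would be accumulated the same way):
import torch

# Sketch of evaluate; averaging and tag names are illustrative
def evaluate(model, criterion, dataloader, device, writer):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():  # no gradients needed for validation
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)
            loss, loss_g1v, loss_g2v = criterion(model(data), label)
            total += loss.item()
            n += 1
    avg = total / n
    if writer:
        writer.add_scalar('loss', avg)
    return avg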
3) Save checkpoints
checkpoint = {
    'model': model_without_ddp.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': epoch,
    'step': step,
    'args': args}
# Save the best checkpoint so far (lowest validation loss)
if loss < best_loss:
    utils.save_on_master(
        checkpoint,
        os.path.join(args.output_path, 'checkpoint.pth'))
    print('saving checkpoint at epoch: ', epoch)
    chp = epoch
    best_loss = loss
# Also save a checkpoint every epoch_block epochs
print('current best loss: ', best_loss)
print('current best epoch: ', chp)
if args.output_path and (epoch + 1) % args.epoch_block == 0:
    utils.save_on_master(
        checkpoint,
        os.path.join(args.output_path, 'model_{}.pth'.format(epoch + 1)))