Yesterday I covered what args does; today I'll keep working through this big series with the main event: train_net.
Let's start by seeing how the whole train_net function runs.
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0
    decays_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = DIMModel(n_classes=1, in_channels=4, is_unpooling=True, pretrain=True)
        model = nn.DataParallel(model)
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model'].module
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = DIMDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    valid_dataset = DIMDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.optimizer == 'sgd' and epochs_since_improvement == 10:
            break
        if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            decays_since_improvement += 1
            print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
            adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))
        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Learning_Rate', effective_lr, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)
        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            decays_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
After some digging, the variables at the top can be annotated as follows:
epochs_since_improvement and decays_since_improvement come into play later in the loop, so I'll explain them in detail when we get there. end_epoch was covered in the previous post on argparse: in short, it's a pre-set default that can be used directly, and it can be overridden from the command line when you want to change it.
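As a reminder of how a value like end_epoch gets its command-line-overridable default, here is a minimal sketch. The flag names and defaults are assumptions for illustration; the real script defines more options (see the previous post).

```python
import argparse

# Minimal sketch of an argparse setup; flag names/defaults are assumptions.
parser = argparse.ArgumentParser(description='Train DIM')
parser.add_argument('--end-epoch', type=int, default=100, help='number of training epochs')
parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')

args = parser.parse_args([])                      # no CLI flags -> defaults
print(args.end_epoch)                             # 100
args = parser.parse_args(['--end-epoch', '30'])   # overridden from the command line
print(args.end_epoch)                             # 30
```

Note that argparse turns `--end-epoch` into the attribute `args.end_epoch` automatically.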
The next step checks whether a checkpoint, i.e. a previously trained model, exists. If not, a fresh model is created, and the optimizer-type argument then decides which optimizer the model gets. For space reasons I won't go into DIMModel itself today; the goal is to walk through the structure of the whole train_net function at a high level.
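The resume branch can be isolated into a tiny round-trip sketch. The real code uses torch.save/torch.load on a dict; pickle and placeholder values stand in here so the sketch runs without PyTorch, and the key names mirror the ones the code reads back.

```python
import os
import pickle
import tempfile

# Toy "checkpoint" round-trip mirroring the resume branch of train_net.
# pickle stands in for torch.save/torch.load; the strings stand in for
# real model/optimizer state.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl')

state = {'epoch': 4,
         'epochs_since_improvement': 2,
         'model': 'model-state-placeholder',
         'optimizer': 'optimizer-state-placeholder'}
with open(ckpt_path, 'wb') as f:
    pickle.dump(state, f)

with open(ckpt_path, 'rb') as f:
    checkpoint = pickle.load(f)
start_epoch = checkpoint['epoch'] + 1      # resume from the *next* epoch
epochs_since_improvement = checkpoint['epochs_since_improvement']
print(start_epoch)  # 5
```

The `+ 1` matters: the saved epoch already finished, so training resumes from the one after it.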
Note the nn.DataParallel line: it lets the model run across multiple GPUs. Training DIM on a single GPU is both memory-bound and painfully slow, so multi-GPU training is the norm. I once tried 8 RTX 3090s and it was lightning fast.
I'll skip the concrete dataset/dataloader implementation this time and jump straight to the main event: the training loop.
# Epochs
for epoch in range(start_epoch, args.end_epoch):
    if args.optimizer == 'sgd' and epochs_since_improvement == 10:
        break
    if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
        checkpoint = 'BEST_checkpoint.tar'
        checkpoint = torch.load(checkpoint)
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        decays_since_improvement += 1
        print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
        adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

    # One epoch's training
    train_loss = train(train_loader=train_loader,
                       model=model,
                       optimizer=optimizer,
                       epoch=epoch,
                       logger=logger)
    effective_lr = get_learning_rate(optimizer)
    print('Current effective learning rate: {}\n'.format(effective_lr))
    writer.add_scalar('Train_Loss', train_loss, epoch)
    writer.add_scalar('Learning_Rate', effective_lr, epoch)

    # One epoch's validation
    valid_loss = valid(valid_loader=valid_loader,
                       model=model,
                       logger=logger)
    writer.add_scalar('Valid_Loss', valid_loss, epoch)

    # Check if there was an improvement
    is_best = valid_loss < best_loss
    best_loss = min(valid_loss, best_loss)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0
        decays_since_improvement = 0

    # Save checkpoint
    save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
At first glance this looks simple, but it's really layers of nesting. The initial state is:
start_epoch = 0
end_epoch = 100 (can be set from the command line)
epochs_since_improvement = 0
decays_since_improvement = 0
From here there's one key thing to figure out: why have these since_improvement counters at all? Right at the top of the loop, training stops once epochs_since_improvement == 10, and that deserves some thought. So I traced just epochs_since_improvement and decays_since_improvement, and only found the answer in the last block of the code.
# One epoch's validation
valid_loss = valid(valid_loader=valid_loader,
                   model=model,
                   logger=logger)
writer.add_scalar('Valid_Loss', valid_loss, epoch)

# Check if there was an improvement
is_best = valid_loss < best_loss
best_loss = min(valid_loss, best_loss)
if not is_best:
    epochs_since_improvement += 1
    print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
else:
    epochs_since_improvement = 0
    decays_since_improvement = 0

# Save checkpoint
save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
As noted earlier, best_loss starts at positive infinity so that the very first valid_loss is guaranteed to update it. When valid_loss is smaller than best_loss, best_loss is updated to the smaller value, is_best becomes True, and both epochs_since_improvement and decays_since_improvement reset to 0. Conversely, when best_loss stays smaller, epochs_since_improvement goes up by 1.
Now go back to the top of the loop: when epochs_since_improvement == 10, the loop breaks. That means the validation loss has failed to improve for 10 epochs in a row, and this is exactly what the counter is for: early stopping, to avoid wasting compute on pointless training. If ten consecutive epochs bring no improvement, there's no reason to keep going.
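The early-stopping logic can be isolated into a tiny sketch. This is pure Python with a made-up loss sequence just for illustration; it replays the counter updates from train_net outside of any real training.

```python
def run_epochs(valid_losses, patience=10):
    """Replay train_net's counter logic on a pre-recorded sequence of
    validation losses; returns how many epochs actually run."""
    best_loss = float('inf')
    epochs_since_improvement = 0
    epochs_run = 0
    for valid_loss in valid_losses:
        if epochs_since_improvement == patience:
            break  # early stop: no improvement for `patience` straight epochs
        epochs_run += 1
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
        else:
            epochs_since_improvement = 0  # any improvement resets the counter
    return epochs_run

# Loss improves for 3 epochs, then plateaus: training stops after
# 3 + 10 = 13 epochs even though 50 more losses are available.
losses = [0.9, 0.8, 0.7] + [0.75] * 50
print(run_epochs(losses))  # 13
```

The key design point is that the counter resets on *any* improvement, so only ten *consecutive* non-improving epochs trigger the stop.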
So what is decays_since_improvement about? Look a bit earlier in the loop.
if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
    checkpoint = 'BEST_checkpoint.tar'
    checkpoint = torch.load(checkpoint)
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']
    decays_since_improvement += 1
    print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
    adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)
On entering an iteration, if the epochs_since_improvement carried over from before is greater than 0 and divisible by 2 (a positive even number), the code works directly off the checkpoint: it reloads the model and optimizer from BEST_checkpoint.tar, increments decays_since_improvement, and calls adjust_learning_rate with a shrink factor of 0.6 ** decays_since_improvement. I've never used the author's provided checkpoint, and in a normal run the loss trends lower after each epoch, so my guess is that starting from a pretrained checkpoint can produce stretches where the best loss stops updating, in which case the code rolls back to the parameters stored in the checkpoint before continuing. In effect this is the familiar "reduce learning rate on plateau, restoring the best weights" pattern; I've since seen it in many training scripts, which makes it well worth understanding.
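The decay step itself can be sketched without PyTorch. A common implementation of adjust_learning_rate simply multiplies each param group's lr by the shrink factor (this is an assumption about this repo's helper, but it matches the usual pattern); the tiny FakeOptimizer class below is a stand-in exposing only param_groups.

```python
class FakeOptimizer:
    """Minimal stand-in for a torch optimizer: exposes only param_groups."""
    def __init__(self, lr):
        self.param_groups = [{'lr': lr}]

def adjust_learning_rate(optimizer, shrink_factor):
    # Multiply every param group's learning rate by shrink_factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor

def get_learning_rate(optimizer):
    return optimizer.param_groups[0]['lr']

# After the 3rd decay, the loop calls adjust_learning_rate(opt, 0.6 ** 3).
# Note it applies the shrink to the optimizer just restored from the best
# checkpoint, so the factor compounds relative to that checkpoint's lr.
opt = FakeOptimizer(lr=0.01)
adjust_learning_rate(opt, 0.6 ** 3)
print(round(get_learning_rate(opt), 6))  # 0.00216
```

Because the exponent is decays_since_improvement rather than a fixed 0.6, each successive plateau shrinks the learning rate more aggressively than the last.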
At this point the overall structure of the training code is laid out: once the preparation is done, each epoch runs the model, computes the losses, and updates. The remaining details are left for later posts:
1. The DIM model architecture
2. What train and valid actually do
3. What writer is for
I'll fill in those gaps later.