2021SC@SDUSC
# Create trainer
trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model), get_device())
class Trainer(object):
    """Training Helper class"""
    def __init__(self, cfg, model, data_iter, optimizer, device):
        self.cfg = cfg
        self.model = model
        self.optimizer = optimizer
        self.device = device

        # data iter
        if len(data_iter) == 1:
            self.sup_iter = data_iter[0]
        elif len(data_iter) == 2:
            self.sup_iter = self.repeat_dataloader(data_iter[0])
            self.unsup_iter = self.repeat_dataloader(data_iter[1])
        elif len(data_iter) == 3:    # uda mode: sup, unsup and eval iters
            self.sup_iter = self.repeat_dataloader(data_iter[0])
            self.unsup_iter = self.repeat_dataloader(data_iter[1])
            self.eval_iter = data_iter[2]

    def repeat_dataloader(self, iterable):
        """ repeat dataloader """
        while True:
            for x in iterable:
                yield x
Notes on yield:
A function containing yield is a generator. Unlike an ordinary function, creating a generator looks like a function call, but no body code runs until next() is called on it (a for loop calls next() automatically). Execution then follows the normal flow of the function, except that it suspends at every yield statement and returns the yielded value; the next call resumes from the statement after that yield. The effect is that of a normally executing function being interrupted several times by yield, with each interruption handing back the current iteration value.
The benefit of yield is obvious: rewriting a function as a generator gives it iteration for free. Compared with keeping state in a class instance to compute the next value for next(), the code is not only more concise but also has a much clearer flow of control.
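To make the pause/resume behaviour concrete, here is a tiny standalone example (illustration only, not part of the project code):

def counter():
    """Each next() runs the body up to the next yield, then suspends."""
    yield 1
    yield 2

g = counter()    # creates the generator; no body code has run yet
print(next(g))   # 1 -- runs to the first yield, then pauses
print(next(g))   # 2 -- resumes right after the first yield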
Notes on generators:
A generator follows the iterator protocol, which requires implementing the __iter__ and __next__ interfaces.
It can be entered and exited many times, suspending execution of the function body in between.
So calling repeat_dataloader returns a generator, not the data itself.
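This endless-repetition behaviour is easy to verify with a plain list standing in for a real DataLoader (a standalone sketch of the same logic):

def repeat_dataloader(iterable):
    """Same logic as Trainer.repeat_dataloader, shown standalone."""
    while True:
        for x in iterable:
            yield x

batches = ['batch0', 'batch1']        # stand-in for a DataLoader
it = repeat_dataloader(batches)
print([next(it) for _ in range(5)])   # ['batch0', 'batch1', 'batch0', 'batch1', 'batch0']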
if cfg.mode == 'train_eval':
    trainer.train(get_loss, get_acc, cfg.model_file, cfg.pretrain_file)
def train(self, get_loss, get_acc, model_file, pretrain_file):
    """ train uda """
    # tensorboardX logging
    if self.cfg.results_dir:
        logger = SummaryWriter(log_dir=os.path.join(self.cfg.results_dir, 'logs'))

    self.model.train()
    self.load(model_file, pretrain_file)    # only one of model_file / pretrain_file is loaded
    model = self.model.to(self.device)
    if self.cfg.data_parallel:              # parallel GPU mode
        model = nn.DataParallel(model)

    global_step = 0
    loss_sum = 0.
    max_acc = [0., 0]    # (best accuracy, step at which it was reached)

    # The progress bar is driven by the unsup or sup data:
    # uda_mode == True  --> sup_iter is repeated
    # uda_mode == False --> sup_iter is not repeated
    iter_bar = tqdm(self.unsup_iter, total=self.cfg.total_steps) if self.cfg.uda_mode \
          else tqdm(self.sup_iter, total=self.cfg.total_steps)

    for i, batch in enumerate(iter_bar):
        # device assignment
        if self.cfg.uda_mode:
            sup_batch = [t.to(self.device) for t in next(self.sup_iter)]
            unsup_batch = [t.to(self.device) for t in batch]
        else:
            sup_batch = [t.to(self.device) for t in batch]
            unsup_batch = None

        # update
        self.optimizer.zero_grad()
        final_loss, sup_loss, unsup_loss = get_loss(model, sup_batch, unsup_batch, global_step)
        final_loss.backward()
        self.optimizer.step()

        # print loss
        global_step += 1
        loss_sum += final_loss.item()
        if self.cfg.uda_mode:
            iter_bar.set_description('final=%5.3f unsup=%5.3f sup=%5.3f' \
                    % (final_loss.item(), unsup_loss.item(), sup_loss.item()))
        else:
            iter_bar.set_description('loss=%5.3f' % final_loss.item())

        # logging
        if self.cfg.uda_mode:
            logger.add_scalars('data/scalar_group',
                               {'final_loss': final_loss.item(),
                                'sup_loss': sup_loss.item(),
                                'unsup_loss': unsup_loss.item(),
                                'lr': self.optimizer.get_lr()[0]
                               }, global_step)
        else:
            logger.add_scalars('data/scalar_group',
                               {'sup_loss': final_loss.item()}, global_step)

        if global_step % self.cfg.save_steps == 0:
            self.save(global_step)

        if get_acc and global_step % self.cfg.check_steps == 0 and global_step > 4999:
            results = self.eval(get_acc, None, model)
            total_accuracy = torch.cat(results).mean().item()
            logger.add_scalars('data/scalar_group', {'eval_acc': total_accuracy}, global_step)
            if max_acc[0] < total_accuracy:
                self.save(global_step)
                max_acc = total_accuracy, global_step
            print('Accuracy : %5.3f' % total_accuracy)
            print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d'
                  % (max_acc[0], max_acc[1], global_step), end='\n\n')

        if self.cfg.total_steps and self.cfg.total_steps < global_step:
            print('The total steps have been reached')
            print('Average Loss %5.3f' % (loss_sum / (i + 1)))
            if get_acc:
                results = self.eval(get_acc, None, model)
                total_accuracy = torch.cat(results).mean().item()
                logger.add_scalars('data/scalar_group', {'eval_acc': total_accuracy}, global_step)
                if max_acc[0] < total_accuracy:
                    max_acc = total_accuracy, global_step
                print('Accuracy :', total_accuracy)
                print('Max Accuracy : %5.3f Max global_steps : %d Cur global_steps : %d'
                      % (max_acc[0], max_acc[1], global_step), end='\n\n')
            self.save(global_step)
            return
    return global_step
This code implements the training loop over labeled and unlabeled data, progressively optimizing the model with gradient descent; tqdm renders the progress bar, and tensorboardX logs the training metrics for visualization.
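The get_loss callback itself is passed in from outside and is not shown here. As rough orientation only, below is a minimal sketch of what such a UDA-style loss typically looks like: supervised cross-entropy on the labeled batch plus a KL-divergence consistency loss between the model's predictions on the original and augmented unlabeled inputs. The batch layouts and the model signature are assumptions for illustration, not the project's actual definitions.

import torch
import torch.nn.functional as F

def get_loss_sketch(model, sup_batch, unsup_batch, global_step):
    """Hypothetical stand-in for the get_loss callback, following the
    usual UDA recipe (assumed batch layouts, not the project's code)."""
    # Supervised part: assume sup_batch = (input_ids, segment_ids, input_mask, label_ids)
    input_ids, segment_ids, input_mask, label_ids = sup_batch
    sup_loss = F.cross_entropy(model(input_ids, segment_ids, input_mask), label_ids)

    if unsup_batch is None:                  # plain supervised training
        return sup_loss, sup_loss, None

    # Unsupervised part: assume unsup_batch packs original and augmented views
    ori_ids, ori_seg, ori_mask, aug_ids, aug_seg, aug_mask = unsup_batch
    with torch.no_grad():                    # the original view is a fixed target
        ori_prob = F.softmax(model(ori_ids, ori_seg, ori_mask), dim=-1)
    aug_log_prob = F.log_softmax(model(aug_ids, aug_seg, aug_mask), dim=-1)
    unsup_loss = F.kl_div(aug_log_prob, ori_prob, reduction='batchmean')

    final_loss = sup_loss + unsup_loss       # matches the (final, sup, unsup) triple used above
    return final_loss, sup_loss, unsup_loss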