Drawing Robust Scratch Tickets的code阅读笔记(第一次读code)(main.py——main_worker()——get_trainer())

进入get_trainer(),查看定义

def get_trainer(args):
    print(f"=> Using trainer from trainers.{args.trainer}")
    trainer = importlib.import_module(f"trainers.{args.trainer}")

    if args.attack_type == 'None':
        return trainer.train, trainer.validate, None, trainer.modifier
    else:
        if args.attack_type == 'free':
            return trainer.train_adv_free, trainer.validate, trainer.validate_adv, trainer.modifier
        else:
            return trainer.train_adv, trainer.validate, trainer.validate_adv, trainer.modifier
  • importlib.import_module(name, package=None):导入模块。name参数指定要以绝对或相对术语导入的模块。pkg.mod或…mod)。如果名称是以相对术语指定的,则包参数必须设置为包的名称,该名称将作为解析包名称的锚点。import_module(’… mod’, ‘pkg.subpkg’)将导入pkg.mod)。具体详情以及相关扩展,点击这里

由此可见,我们该进入trainer.{arg.trainer}(由args.py得此处默认值为trainer.default)查看后,便可根据不同的args.attack_type得到return不同函数。
下面就以trainer.default追踪到的default.py为例,查看主要训练评估过程。由于他们的整体构造都差不多,我们就主要看train()的部分

def train(train_loader, model, criterion, optimizer, epoch, args, writer, log):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]",
    )

以上主要是在进行参数存储和更新的管理。

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = len(train_loader)
    end = time.time()
    for i, (X, y) in tqdm.tqdm(
        enumerate(train_loader), ascii=True, total=len(train_loader)
    ):
        # measure data loading time
        data_time.update(time.time() - end)

        X = X.cuda()
        y = y.cuda()

tqdm: Python 进度条库,可以在 Python 长循环中添加一个进度提示信息。详情点击这里
X.cuda():将CPU上的Tensor或变量放到GPU上。

        # compute output
        output = model(X)

        loss = criterion(output, y)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, y, topk=(1, 5))
        losses.update(loss.item(), X.size(0))
        top1.update(acc1.item(), X.size(0))
        top5.update(acc5.item(), X.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)
            progress.write_to_tensorboard(writer, prefix="train", global_step=t)

    return top1.avg, top5.avg

后面就是进行相关acc、loss的计算并更新存储得到最终结果。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值