Drawing Robust Scratch Tickets的code阅读笔记(第一次读code)(main.py——main_worker())

点开main.py文件

import os
import pathlib
import random
import time
import shutil
import math

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

from utils.conv_type import FixedSubnetConv, SampleSubnetConv
from utils.logging import AverageMeter, ProgressMeter
from utils.net_utils import (
    set_model_prune_rate,
    freeze_model_weights,
    freeze_model_subnet,
    save_checkpoint,
    get_lr,
    LabelSmoothing,
    init_model_weight_with_score,
)
from utils.schedulers import get_policy
import logging

from args import args
import importlib

import data
import models

from utils.builder import get_builder

入眼都是import什么的,大家知道这是大概是引入moudle的意思就行,想要详细了解的点击这里
然后再往下是一些def,就是一些函数定义,我们直接先翻到run code的地方。

if __name__ == "__main__":
    main()

就是run main()的意思,直接右键转到定义

def main():
    # print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Simply call main_worker function
    main_worker(args)

上面大块的seed,可以得知大概是生成seed的用处,对于我们整个code的特点理解用途不大,直接跳过。
后面就是在run main_worker(args),在这里我右键看了下args的定义,在args.py里面,大概有进行参数的解析、帮助消息和误用参数时自动抛错的作用,也不算这个code的特点,这里就不细究了(详细了解点击这里)。直接进入main_worker(args)的定义:(很长,可以看出我们主要的看code任务就在这了,我们分块查看)

def main_worker(args):
    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    log = logging.getLogger(__name__)
    log_path = os.path.join(run_base_dir, 'log.txt')
    handlers = [logging.FileHandler(log_path, mode='a+'),
                logging.StreamHandler()]
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO,
        handlers=handlers)
    log.info(args)
    
    if args.attack_type == 'free' and args.set == 'ImageNet':
        args.lr_policy = 'multistep_lr_imagenet_free'
        args.epochs = int(math.ceil(args.epochs / args.n_repeats))

    train, validate, validate_adv, modifier = get_trainer(args)
  • get_directories(args):进入函数查看后,得知主要是写入了run_base_dir、ckpt_base_dir、log_base_dir的地址。
  • math.ceil():返回大于或等于一个给定数字的最小整数。
  • get_trainer():作者自己定义的函数,有需可查看我的相关博客。

其他中间的很多有关logger的code,大概就是为了获得相关日志信息,不需要重点关注。

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)
  • get_model():作者自己定义的函数,这里是建立model的地方,应重点关注,有需可查看我的相关博客。
  • set_gpu():由名字可得是选择GPU的地方,无需重点关注。
    if args.task != 'search':
        if args.pretrained is None:
            path = run_base_dir.parent / 'search' / 'checkpoints'/ 'model_best.pth'
            if os.path.exists(path):
                args.pretrained = path
            else:
                path = run_base_dir.parent / 'checkpoints' / 'model_best.pth'
                if os.path.exists(path):
                    args.pretrained = path
                else:
                    print('No pretrained checkpoint:', path)
                    exit()
        pretrained(args, model)

    elif args.pretrained:
        pretrained(args, model)

以上一段code主要进行pretrain的工作。
以README.md(有需可查看我的相关博客)的To search an RST from a randomly initialized PreActResNet18 on CIFAR-10部分code为例。无pretrained的导入,查看args.py可得默认args.pretrained='None'args.task = 'search',故不参与pretrained。
若以README.md的To finetune the searched RST from PreActResNet18 on CIFAR-10 with inherited model weights为例。得args.pretrained='path-to-searched-rst'args.task = 'ft_inherit',将进入pretrained()(有需可查看我的相关博客)。

    # freezing the weights if we are only doing subnet training
    if args.task == 'search':
        freeze_model_weights(model)
    
    else:
        # freezing the subnet and finetuning the model weights
        freeze_model_subnet(model)

这里相关作用函数的名字已经明确表明了,读者也可以右键进去看看相关定义,不难理解。需要了解的:

  • model.named_modules():不但返回模型的所有子层,还会返回这些层的名字,点此详细了解。
  • hasattr(object, name):如果对象(object)有该属性(name)返回 True,否则返回 False。
        # finetune the robust ticket with random weight intialization
        if args.task == 'ft_reinit':
            args.init = args.ft_init
            builder = get_builder()
            for module in model.modules():
                if isinstance(module, nn.Conv2d):
                    builder._init_conv(module)

        # finetune the whole model with the initialization to be the robust ticket
        if args.task == 'ft_full':
            init_model_weight_with_score(model, prune_rate=args.prune_rate)
            # set_model_prune_rate(model, prune_rate=1.0)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

以上主要是进行了一些微调的工作,有关get_builder()的分析可以,有需要的可查看我的相关博客。
后又进行了 optimizer、data、lr_policy、lossfunction等相关设置,右键进去,可得都是一些常用的code,具体内容,通过名称都可以大致得到了解。

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    natural_acc1_at_best_robustness = None

    if args.automatic_resume:
        args.resume = ckpt_base_dir / 'model_latest.pth'
        if os.path.isfile(args.resume):
            best_acc1, natural_acc1_at_best_robustness = resume(args, model, optimizer)
        else:
            print('Train from scratch.')

    elif args.resume:
        best_acc1, natural_acc1_at_best_robustness = resume(args, model, optimizer)

以上主要是从checkpoint选择进行resume的操作,不需要重点关注。

    # Data loading code
    if args.evaluate:
        if args.attack_type != 'None':
            acc1, acc5 = validate_adv(
                data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
            )

            natural_acc1, natural_acc5 = validate(
                data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
            )

            log.info('Natural Acc: %.2f, Robust Acc: %.2f', natural_acc1, acc1)
        
        else:
            acc1, acc5 = validate(
                data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
            )

            log.info('Natural Acc: %.2f', acc1)

        return

以上主要是验证得到acc,主要有关函数在前面get_trainer()中已提到。

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(
        1, [epoch_time, validation_time, train_time], prefix="Overall Timing"
    )

主要是进行一些记录和更新,无需重点关注。

  • SummaryWriter():将条目直接写入 log_dir 中的事件文件以供 TensorBoard 使用。
  • AverageMeter():此处是管理一些变量的更新。
    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            'natural_acc1_at_best_robustness': natural_acc1_at_best_robustness,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / f"initial.state",
        save=False,
    )

    if args.discard_mode and args.progressive_prune:
        set_model_prune_rate(model, prune_rate=0.9)

同样类似于存储工作,记录初始状态。

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(
            data.train_loader, model, criterion, optimizer, epoch, args, writer=writer, log=log
        )

开始进行训练,主要的train()已经在上述get_trainer()中提出。

        if args.discard_mode:
            if (epoch+1) % args.discard_epoch == 0:
                for n, m in model.named_modules():
                    if hasattr(m, "discard_low_score"):
                        m.discard_low_score(min(args.discard_rate * ((epoch+1)//args.discard_epoch), 1))

        train_time.update((time.time() - start_train) / 60)

        # if 'ImageNet' in args.set:
        #     start_epoch = 30
        # else:
        #     start_epoch = 60

        # if args.optimizer == 'sgd':
        #     val_every = args.val_every if epoch > start_epoch else 10
        # else:
        #     val_every = args.val_every

        if epoch % args.val_every == 0 or epoch == args.epochs - 1:
            # evaluate on validation set
            start_validation = time.time()

            if args.attack_type != 'None':
                acc1, acc5 = validate_adv(data.val_loader, model, criterion, args, writer, epoch)
                natural_acc1, natural_acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
            else:
                acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)

            validation_time.update((time.time() - start_validation) / 60)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            best_acc5 = max(acc5, best_acc5)
            best_train_acc1 = max(train_acc1, best_train_acc1)
            best_train_acc5 = max(train_acc5, best_train_acc5)

            if is_best and args.attack_type != 'None':
                natural_acc1_at_best_robustness = natural_acc1
                
            if is_best:
                log.info(f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}")

            if is_best or epoch == args.epochs - 1:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "arch": args.arch,
                        "state_dict": model.state_dict(),
                        "best_acc1": best_acc1,
                        "best_acc5": best_acc5,
                        "best_train_acc1": best_train_acc1,
                        "best_train_acc5": best_train_acc5,
                        "natural_acc1_at_best_robustness": natural_acc1_at_best_robustness,
                        "optimizer": optimizer.state_dict(),
                        "curr_acc1": acc1,
                        "curr_acc5": acc5,
                    },
                    is_best,
                    filename=ckpt_base_dir / f"epoch_{epoch}.state",
                    save=False,
                )

            if args.attack_type != 'None':
                log.info('Epoch[%d][%d] curr natural acc: %.2f, natural acc at best robustness: %.2f \n curr robust acc: %.2f, best robust acc: %.2f', 
                        args.epochs, epoch, natural_acc1, natural_acc1_at_best_robustness, acc1, best_acc1)
            else:
                log.info('Epoch[%d][%d] curr acc: %.2f, best acc: %.2f', args.epochs, epoch, acc1, best_acc1)
        
        elif 'ImageNet' in args.set:
            save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "arch": args.arch,
                        "state_dict": model.state_dict(),
                        "best_acc1": best_acc1,
                        "best_acc5": best_acc5,
                        "best_train_acc1": best_train_acc1,
                        "best_train_acc5": best_train_acc5,
                        "natural_acc1_at_best_robustness": natural_acc1_at_best_robustness,
                        "optimizer": optimizer.state_dict(),
                        "curr_acc1": None,
                        "curr_acc5": None,
                    },
                    is_best=False,
                    filename=ckpt_base_dir / f"epoch_{epoch}.state",
                    save=False,
                )

        # if args.conv_type == "SampleSubnetConv":
        #     count = 0
        #     sum_pr = 0.0
        #     for n, m in model.named_modules():
        #         if isinstance(m, SampleSubnetConv):
        #             # avg pr across 10 samples
        #             pr = 0.0
        #             for _ in range(10):
        #                 pr += (
        #                     (torch.rand_like(m.clamped_scores) >= m.clamped_scores)
        #                     .float()
        #                     .mean()
        #                     .item()
        #                 )
        #             pr /= 10.0
        #             writer.add_scalar("pr/{}".format(n), pr, epoch)
        #             sum_pr += pr
        #             count += 1

        #     args.prune_rate = sum_pr / count
        #     writer.add_scalar("pr/average", args.prune_rate, epoch)

        if args.discard_mode:
            if (epoch+1) % args.discard_epoch == 0:
                if args.progressive_prune:
                    set_model_prune_rate(model, prune_rate=0.9-min(args.discard_rate * ((epoch+1)//args.discard_epoch), 1))

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(
            writer, prefix="diagnostics", global_step=epoch
        )

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    # write_result_to_csv(
    #     best_acc1=best_acc1,
    #     best_acc5=best_acc5,
    #     best_train_acc1=best_train_acc1,
    #     best_train_acc5=best_train_acc5,
    #     prune_rate=args.prune_rate,
    #     curr_acc1=acc1,
    #     curr_acc5=acc5,
    #     base_config=args.config,
    #     name=args.name,
    # )

    log_dir_new = 'logs/log_'+args.name
    if not os.path.exists(log_dir_new):
        os.makedirs(log_dir_new)
    
    shutil.copyfile(log_path, os.path.join(log_dir_new, 'log_'+args.task+'.txt'))

之后主要都是在记录、更新数据了,属于train之中。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值