点开main.py文件
import os
import pathlib
import random
import time
import shutil
import math
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from utils.conv_type import FixedSubnetConv, SampleSubnetConv
from utils.logging import AverageMeter, ProgressMeter
from utils.net_utils import (
set_model_prune_rate,
freeze_model_weights,
freeze_model_subnet,
save_checkpoint,
get_lr,
LabelSmoothing,
init_model_weight_with_score,
)
from utils.schedulers import get_policy
import logging
from args import args
import importlib
import data
import models
from utils.builder import get_builder
入眼都是import什么的,大家知道这大概是引入module(模块)的意思就行,想要详细了解的点击这里。
然后再往下是一些def,就是一些函数定义,我们直接先翻到run code的地方。
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
就是run main()
的意思,直接右键转到定义
def main():
    """Entry point: seed every RNG source (if a seed was given), then hand off to main_worker."""
    if args.seed is not None:
        # Seed Python, CPU-torch and all CUDA devices so runs are reproducible.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    # Simply call main_worker function
    main_worker(args)
上面大块有关seed的代码,作用是设置随机种子以保证实验可复现,对于我们整个code的特点理解用途不大,直接跳过。
后面就是在run main_worker(args)
,在这里我右键看了下args的定义,在args.py里面,大概有进行参数的解析、帮助消息和误用参数时自动抛错的作用,也不算这个code的特点,这里就不细究了(详细了解点击这里)。直接进入main_worker(args)
的定义:(很长,可以看出我们主要的看code任务就在这了,我们分块查看)
def main_worker(args):
# Orchestrates one full run: directory/logging setup, model and optimizer
# creation, optional pretrained loading / resume, then the train/validate loop.
# NOTE(review): original indentation was lost in extraction; code kept verbatim.
# Set up directories
run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
args.ckpt_base_dir = ckpt_base_dir
# Log both to a file under the run directory and to the console.
log = logging.getLogger(__name__)
log_path = os.path.join(run_base_dir, 'log.txt')
handlers = [logging.FileHandler(log_path, mode='a+'),
logging.StreamHandler()]
logging.basicConfig(
format='[%(asctime)s] - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S',
level=logging.INFO,
handlers=handlers)
log.info(args)
# For 'free' adversarial training on ImageNet: switch the LR schedule and
# shrink the epoch count by n_repeats (presumably each batch is replayed
# n_repeats times — TODO confirm in the trainer).
if args.attack_type == 'free' and args.set == 'ImageNet':
args.lr_policy = 'multistep_lr_imagenet_free'
args.epochs = int(math.ceil(args.epochs / args.n_repeats))
train, validate, validate_adv, modifier = get_trainer(args)
get_directories(args)
:进入函数查看后,得知主要是写入了run_base_dir、ckpt_base_dir、log_base_dir的地址。math.ceil()
:返回大于或等于一个给定数字的最小整数。get_trainer()
:作者自己定义的函数,有需可查看我的相关博客。
其他中间的很多有关logger的code,大概就是为了获得相关日志信息,不需要重点关注。
# create model and optimizer
# Build the model, then move/parallelize it onto the configured GPU(s).
model = get_model(args)
model = set_gpu(args, model)
get_model()
:作者自己定义的函数,这里是建立model的地方,应重点关注,有需可查看我的相关博客。set_gpu()
:由名字可得是选择GPU的地方,无需重点关注。
# For any non-'search' task, locate a pretrained checkpoint automatically when
# none was given: first try <run>/search/checkpoints/model_best.pth, then
# <run>/checkpoints/model_best.pth; abort if neither exists.
# NOTE(review): indentation was lost in extraction, so the exact nesting of
# these branches must be confirmed against the original file.
if args.task != 'search':
if args.pretrained is None:
path = run_base_dir.parent / 'search' / 'checkpoints'/ 'model_best.pth'
if os.path.exists(path):
args.pretrained = path
else:
path = run_base_dir.parent / 'checkpoints' / 'model_best.pth'
if os.path.exists(path):
args.pretrained = path
else:
print('No pretrained checkpoint:', path)
exit()
pretrained(args, model)
elif args.pretrained:
pretrained(args, model)
以上一段code主要进行pretrain的工作。
以README.md(有需可查看我的相关博客)的To search an RST from a randomly initialized PreActResNet18 on CIFAR-10部分code为例。无pretrained的导入,查看args.py可得默认args.pretrained=None
且args.task = 'search'
,故不参与pretrained。
若以README.md的To finetune the searched RST from PreActResNet18 on CIFAR-10 with inherited model weights为例。得args.pretrained='path-to-searched-rst'
且args.task = 'ft_inherit'
,将进入pretrained()
(有需可查看我的相关博客)。
# freezing the weights if we are only doing subnet training
if args.task == 'search':
freeze_model_weights(model)
else:
# freezing the subnet and finetuning the model weights
freeze_model_subnet(model)
这里相关作用函数的名字已经明确表明了,读者也可以右键进去看看相关定义,不难理解。需要了解的:
model.named_modules()
:不但返回模型的所有子层,还会返回这些层的名字,点此详细了解。hasattr(object, name)
:如果对象(object)有该属性(name)返回 True,否则返回 False。
# finetune the robust ticket with random weight initialization
if args.task == 'ft_reinit':
args.init = args.ft_init
builder = get_builder()
# Re-initialize every conv layer through the builder's init scheme.
for module in model.modules():
if isinstance(module, nn.Conv2d):
builder._init_conv(module)
# finetune the whole model with the initialization to be the robust ticket
if args.task == 'ft_full':
init_model_weight_with_score(model, prune_rate=args.prune_rate)
# set_model_prune_rate(model, prune_rate=1.0)
# NOTE(review): this local 'data' shadows the module-level 'import data'.
optimizer = get_optimizer(args, model)
data = get_dataset(args)
lr_policy = get_policy(args.lr_policy)(optimizer, args)
# Plain cross-entropy unless label smoothing is requested.
if args.label_smoothing is None:
criterion = nn.CrossEntropyLoss().cuda()
else:
criterion = LabelSmoothing(smoothing=args.label_smoothing)
以上主要是进行了一些微调的工作,有关get_builder()
的分析可以,有需要的可查看我的相关博客。
后又进行了 optimizer、data、lr_policy、lossfunction等相关设置,右键进去,可得都是一些常用的code,具体内容,通过名称都可以大致得到了解。
# optionally resume from a checkpoint
best_acc1 = 0.0
best_acc5 = 0.0
best_train_acc1 = 0.0
best_train_acc5 = 0.0
natural_acc1_at_best_robustness = None
# automatic_resume picks up model_latest.pth if present; otherwise fall back
# to an explicit --resume path, else start from scratch.
if args.automatic_resume:
args.resume = ckpt_base_dir / 'model_latest.pth'
if os.path.isfile(args.resume):
best_acc1, natural_acc1_at_best_robustness = resume(args, model, optimizer)
else:
print('Train from scratch.')
elif args.resume:
best_acc1, natural_acc1_at_best_robustness = resume(args, model, optimizer)
以上主要是从checkpoint选择进行resume的操作,不需要重点关注。
# Data loading code
# Evaluation-only mode: run (adversarial and/or natural) validation once,
# log the accuracies, and return without training.
if args.evaluate:
if args.attack_type != 'None':
acc1, acc5 = validate_adv(
data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
)
natural_acc1, natural_acc5 = validate(
data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
)
log.info('Natural Acc: %.2f, Robust Acc: %.2f', natural_acc1, acc1)
else:
acc1, acc5 = validate(
data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
)
log.info('Natural Acc: %.2f', acc1)
return
以上主要是验证得到acc,主要有关函数在前面get_trainer()
中已提到。
# TensorBoard writer plus per-phase timing meters (minutes per epoch).
writer = SummaryWriter(log_dir=log_base_dir)
epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
train_time = AverageMeter("train_time", ":.4f", write_avg=False)
progress_overall = ProgressMeter(
1, [epoch_time, validation_time, train_time], prefix="Overall Timing"
)
主要是进行一些记录和更新,无需重点关注。
SummaryWriter()
:将条目直接写入 log_dir 中的事件文件以供 TensorBoard 使用。AverageMeter()
:此处是管理一些变量的更新。
end_epoch = time.time()
args.start_epoch = args.start_epoch or 0
acc1 = None
# Save the initial state
save_checkpoint(
{
"epoch": 0,
"arch": args.arch,
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"best_acc5": best_acc5,
"best_train_acc1": best_train_acc1,
"best_train_acc5": best_train_acc5,
'natural_acc1_at_best_robustness': natural_acc1_at_best_robustness,
"optimizer": optimizer.state_dict(),
"curr_acc1": acc1 if acc1 else "Not evaluated",
},
False,
filename=ckpt_base_dir / f"initial.state",
save=False,
)
# Progressive pruning starts from a high keep rate of 0.9 and is tightened
# later inside the epoch loop.
if args.discard_mode and args.progressive_prune:
set_model_prune_rate(model, prune_rate=0.9)
同样类似于存储工作,记录初始状态。
# Start training
for epoch in range(args.start_epoch, args.epochs):
# Step the LR schedule and apply the per-epoch modifier before training.
lr_policy(epoch, iteration=None)
modifier(args, epoch, model)
cur_lr = get_lr(optimizer)
# train for one epoch
start_train = time.time()
train_acc1, train_acc5 = train(
data.train_loader, model, criterion, optimizer, epoch, args, writer=writer, log=log
)
开始进行训练,主要的train()
已经在上述get_trainer()
中提出。
# Every discard_epoch epochs, ask modules exposing a discard_low_score hook
# to drop low-scoring entries; the rate ramps up with epoch, capped at 1.
if args.discard_mode:
if (epoch+1) % args.discard_epoch == 0:
for n, m in model.named_modules():
if hasattr(m, "discard_low_score"):
m.discard_low_score(min(args.discard_rate * ((epoch+1)//args.discard_epoch), 1))
train_time.update((time.time() - start_train) / 60)
# if 'ImageNet' in args.set:
# start_epoch = 30
# else:
# start_epoch = 60
# if args.optimizer == 'sgd':
# val_every = args.val_every if epoch > start_epoch else 10
# else:
# val_every = args.val_every
# Validate every val_every epochs and always on the final epoch.
if epoch % args.val_every == 0 or epoch == args.epochs - 1:
# evaluate on validation set
start_validation = time.time()
if args.attack_type != 'None':
acc1, acc5 = validate_adv(data.val_loader, model, criterion, args, writer, epoch)
natural_acc1, natural_acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
else:
acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
validation_time.update((time.time() - start_validation) / 60)
# remember best acc@1 and save checkpoint
# When attacking, "best" is best *robust* acc1; the natural acc1 measured at
# that same epoch is stored alongside it.
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
best_acc5 = max(acc5, best_acc5)
best_train_acc1 = max(train_acc1, best_train_acc1)
best_train_acc5 = max(train_acc5, best_train_acc5)
if is_best and args.attack_type != 'None':
natural_acc1_at_best_robustness = natural_acc1
if is_best:
log.info(f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}")
# Persist a checkpoint on any new best and on the last epoch.
if is_best or epoch == args.epochs - 1:
save_checkpoint(
{
"epoch": epoch + 1,
"arch": args.arch,
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"best_acc5": best_acc5,
"best_train_acc1": best_train_acc1,
"best_train_acc5": best_train_acc5,
"natural_acc1_at_best_robustness": natural_acc1_at_best_robustness,
"optimizer": optimizer.state_dict(),
"curr_acc1": acc1,
"curr_acc5": acc5,
},
is_best,
filename=ckpt_base_dir / f"epoch_{epoch}.state",
save=False,
)
if args.attack_type != 'None':
log.info('Epoch[%d][%d] curr natural acc: %.2f, natural acc at best robustness: %.2f \n curr robust acc: %.2f, best robust acc: %.2f',
args.epochs, epoch, natural_acc1, natural_acc1_at_best_robustness, acc1, best_acc1)
else:
log.info('Epoch[%d][%d] curr acc: %.2f, best acc: %.2f', args.epochs, epoch, acc1, best_acc1)
# On non-validation epochs, ImageNet runs still checkpoint every epoch
# (accuracies recorded as None since no evaluation was done).
elif 'ImageNet' in args.set:
save_checkpoint(
{
"epoch": epoch + 1,
"arch": args.arch,
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"best_acc5": best_acc5,
"best_train_acc1": best_train_acc1,
"best_train_acc5": best_train_acc5,
"natural_acc1_at_best_robustness": natural_acc1_at_best_robustness,
"optimizer": optimizer.state_dict(),
"curr_acc1": None,
"curr_acc5": None,
},
is_best=False,
filename=ckpt_base_dir / f"epoch_{epoch}.state",
save=False,
)
# if args.conv_type == "SampleSubnetConv":
# count = 0
# sum_pr = 0.0
# for n, m in model.named_modules():
# if isinstance(m, SampleSubnetConv):
# # avg pr across 10 samples
# pr = 0.0
# for _ in range(10):
# pr += (
# (torch.rand_like(m.clamped_scores) >= m.clamped_scores)
# .float()
# .mean()
# .item()
# )
# pr /= 10.0
# writer.add_scalar("pr/{}".format(n), pr, epoch)
# sum_pr += pr
# count += 1
# args.prune_rate = sum_pr / count
# writer.add_scalar("pr/average", args.prune_rate, epoch)
# Progressive pruning: tighten the keep rate from 0.9 downward in step with
# the discard schedule above.
if args.discard_mode:
if (epoch+1) % args.discard_epoch == 0:
if args.progressive_prune:
set_model_prune_rate(model, prune_rate=0.9-min(args.discard_rate * ((epoch+1)//args.discard_epoch), 1))
# End-of-epoch bookkeeping: timing meters, console display, TensorBoard.
epoch_time.update((time.time() - end_epoch) / 60)
progress_overall.display(epoch)
progress_overall.write_to_tensorboard(
writer, prefix="diagnostics", global_step=epoch
)
writer.add_scalar("test/lr", cur_lr, epoch)
end_epoch = time.time()
# write_result_to_csv(
# best_acc1=best_acc1,
# best_acc5=best_acc5,
# best_train_acc1=best_train_acc1,
# best_train_acc5=best_train_acc5,
# prune_rate=args.prune_rate,
# curr_acc1=acc1,
# curr_acc5=acc5,
# base_config=args.config,
# name=args.name,
# )
# After training, copy the run log into a per-name/per-task location.
log_dir_new = 'logs/log_'+args.name
if not os.path.exists(log_dir_new):
os.makedirs(log_dir_new)
shutil.copyfile(log_path, os.path.join(log_dir_new, 'log_'+args.task+'.txt'))
之后主要都是在记录、更新数据了,属于train之中。