目录
def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):
创建了用于记录训练过程中各种指标的AverageMeter对象
使用教师模型和学生模型分别对总的输入inputs_all进行前向传播
def get_freer_gpu():
def get_freer_gpu():
os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
memory_available = [int(3944), int(4095)]
# memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
return np.argmax(memory_available)
os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu())
这段代码的作用是获取当前可用的显存最大的 GPU 设备,并将该设备的索引号设置为 CUDA 可见设备的环境变量。
首先定义了一个名为 get_freer_gpu()
的函数。该函数通过运行系统命令 nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp
,获取全部 GPU 设备可用显存大小,并将输出结果保存到名为 "tmp" 的文件中。接着,在内存可用数组中定义了两个整数,分别表示两个 GPU 设备的可用显存大小,这里提前定义的原因可能是部分 GPU 设备的可用显存大小是固定的,而无法通过命令获取。
然后执行 np.argmax(memory_available)
函数来获取当前可用显存最大的 GPU 设备的索引号,并返回该索引号,这个索引号表示在系统中GPU设备的编号。(这看起来不像设备索引号,像最大可用内存空间大小啊?)
最后,通过将该索引号转化为字符串并将其赋值给环境变量 "CUDA_VISIBLE_DEVICES"
,可以将该 GPU 设备设置为 CUDA 可见设备,让程序在运行时只使用该 GPU 设备进行计算,以充分利用 GPU 的计算加速能力。
class Wrapper(nn.Module):
class Wrapper(nn.Module):
def __init__(self, model, args):
super(Wrapper, self).__init__()
self.model = model # 看C++语法,这里已经忘了
self.feat = torch.nn.Sequential(*list(self.model.children())[:-2])
self.last = torch.nn.Linear(list(self.model.children())[-2].in_features, 64)
def forward(self, images):
feat = self.feat(images)
feat = feat.view(images.size(0), -1)
out = self.last(feat)
return feat, out
这个类名为 Wrapper
,它是一个继承自 nn.Module
的子类,用于包装神经网络模型。
在 Wrapper
类的构造函数 __init__()
中,接收两个参数 model
和 args
。model
是一个已经定义好的神经网络模型,args
则可能是其他配置参数。
构造函数中的 super(Wrapper, self).__init__()
语句是调用父类 nn.Module
的构造函数进行初始化。# 这个没看懂,得复习一下C++语法
接下来,通过 self.model = model
将传入的模型赋值给 Wrapper
类的成员变量 self.model
。
然后,使用 torch.nn.Sequential()
构建了一个新的神经网络层序列 self.feat
,该序列由 self.model
的所有子模型(即所有孩子节点)去掉最后两个子模型组成。可以理解为 self.feat
是 self.model
倒数第三个子模型的输出。
再之后,使用 torch.nn.Linear()
构造了一个线性层 self.last
,该层的输入大小为 self.model
倒数第二子模型的输入特征大小,输出大小为 64。
forward()
方法定义了前向传播过程。给定输入 images
,首先将其通过 self.feat
进行特征提取,得到特征 feat
。然后将 feat
进行展平操作,正好可以将 images
的大小为 (batch_size, channels, height, width)
的张量转换成了 (batch_size, num_features)
。接着,将展平后的特征 feat
作为输入传递给线性层 self.last
,得到输出 out
。最后,返回特征 feat
和输出 out
。
通过使用这个 Wrapper
类,可以方便地访问封装模型的不同层输出,尤其是在需要获取中间层特征时很有用。
def parse_option():
def parse_option():
parser = argparse.ArgumentParser('argument for training')
parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
# optimization
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
# dataset and model
parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool)
parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool)
parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
'CIFAR-FS', 'FC100'])
parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment')
parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
# path to teacher model
parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot')
# distillation
parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention'])
parser.add_argument('--trial', type=str, default='1', help='trial id')
parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses')
# KL distillation
parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation')
# NCE distillation
parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')
# cosine annealing
parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
# specify folder
parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')
# setting for meta-learning
parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
help='Number of test runs')
parser.add_argument('--n_ways', type=int, default=5, metavar='N',
help='Number of classes for doing each classification run')
parser.add_argument('--n_shots', type=int, default=1, metavar='N',
help='Number of shots in test')
parser.add_argument('--n_queries', type=int, default=15, metavar='N',
help='Number of query in test')
parser.add_argument('--n_aug_support_samples', default=5, type=int,
help='The number of augmented samples for each meta test sample')
parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
help='Size of test batch)')
opt = parser.parse_args()
if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
opt.transform = 'D'
if 'trainval' in opt.path_t:
opt.use_trainval = True
else:
opt.use_trainval = False
if opt.use_trainval:
opt.trial = opt.trial + '_trainval'
# set the path according to the environment
if not opt.model_path:
opt.model_path = './models_distilled'
if not opt.tb_path:
opt.tb_path = './tensorboard'
if not opt.data_root:
opt.data_root = './data/{}'.format(opt.dataset)
else:
opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
opt.data_aug = True
tags = opt.tags.split(',')
opt.tags = list([])
for it in tags:
opt.tags.append(it)
iterations = opt.lr_decay_epochs.split(',')
opt.lr_decay_epochs = list([])
for it in iterations:
opt.lr_decay_epochs.append(int(it))
opt.model_name = 'S:{}_T:{}_{}_{}_r:{}_a:{}_b:{}_trans_{}'.format(opt.model_s, opt.model_t, opt.dataset,
opt.distill, opt.gamma, opt.alpha, opt.beta,
opt.transform)
if opt.cosine:
opt.model_name = '{}_cosine'.format(opt.model_name)
opt.model_name = '{}_{}'.format(opt.model_name, opt.trial)
opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
if not os.path.isdir(opt.tb_folder):
os.makedirs(opt.tb_folder)
opt.save_folder = os.path.join(opt.model_path, opt.model_name)
if not os.path.isdir(opt.save_folder):
os.makedirs(opt.save_folder)
#extras
opt.fresh_start = True
return opt
上述代码是一个用于解析命令行参数的函数。该函数使用argparse库来解析参数,并返回一个包含解析结果的对象opt。
解析器定义了一系列的参数,如eval_freq、print_freq、tb_freq等。每个参数都有自己的类型、默认值和帮助信息。
在函数内部,使用parser.parse_args()方法对命令行参数进行解析,并将解析结果保存在opt对象中。
随后,根据一些特殊的逻辑对opt的某些属性进行了一些调整和赋值。例如,如果数据集是'CIFAR-FS'或'FC100',则将opt.transform设置为'D';如果opt.path_t中包含'trainval',则将opt.use_trainval设置为True。
最后,根据一些规则设置了模型名称、存储路径等其他属性,并返回了opt对象。
需要注意的是,上述代码只是对解析器的定义和一些参数属性的初始化操作,并没有执行实际的参数解析。实际的参数解析需要在调用parse_option()函数时进行。
parser
def parse_option():
parser = argparse.ArgumentParser('argument for training')
parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
# optimization
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
# dataset and model
parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool)
parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool)
parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
'CIFAR-FS', 'FC100'])
parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment')
parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
# path to teacher model
parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot')
# distillation
parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention'])
parser.add_argument('--trial', type=str, default='1', help='trial id')
parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses')
# KL distillation
parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation')
# NCE distillation
parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')
# cosine annealing
parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
# specify folder
parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')
# setting for meta-learning
parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
help='Number of test runs')
parser.add_argument('--n_ways', type=int, default=5, metavar='N',
help='Number of classes for doing each classification run')
parser.add_argument('--n_shots', type=int, default=1, metavar='N',
help='Number of shots in test')
parser.add_argument('--n_queries', type=int, default=15, metavar='N',
help='Number of query in test')
parser.add_argument('--n_aug_support_samples', default=5, type=int,
help='The number of augmented samples for each meta test sample')
parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
help='Size of test batch)')
定义了许多参数。
这些参数可以从命令行接收,也可以在代码中进行硬编码或通过其他方式进行设置。根据 parser.add_argument()
的用法,这些参数可以在命令行中使用 --参数名 参数值
的形式进行设置,例如 --model_path save/
。
除了命令行参数外,还可以通过其他方式来设置这些参数的默认值。例如,在代码中直接为这些参数赋予默认值,如 default='save/'
、default='tb/'
和 default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/'
。这样在运行代码时,如果没有通过命令行传递这些参数,则会使用默认值。
另外,你还可以通过其他方式传递这些参数,比如从配置文件中读取,或者通过函数调用时传递参数值。总之,argparse.ArgumentParser
可以用于解析命令行参数,但它并不限制参数只能从命令行接收,你可以根据自己的需求选择不同的方式来设置这些参数的值。
opt = parser.parse_args()
opt = parser.parse_args()
if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
opt.transform = 'D'
if 'trainval' in opt.path_t:
opt.use_trainval = True
else:
opt.use_trainval = False
if opt.use_trainval:
opt.trial = opt.trial + '_trainval'
# set the path according to the environment
if not opt.model_path:
opt.model_path = './models_distilled'
if not opt.tb_path:
opt.tb_path = './tensorboard'
if not opt.data_root:
opt.data_root = './data/{}'.format(opt.dataset)
else:
opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
opt.data_aug = True
这段代码中首先通过 parser.parse_args()
解析命令行参数,将所需的参数值存储在一个名为 opt
的对象中。
接下来,通过判断 opt.dataset
是否为 'CIFAR-FS'
或 'FC100'
来设置 opt.transform
为 'D'
。这里的作用是区分数据集不同,因为对于 'CIFAR-FS'
和 'FC100'
数据集来说,数据增强的方式不同,'D'
方式适用于这两个数据集。
然后,通过判断 opt.path_t
中是否包含 'trainval'
来设置 opt.use_trainval
为 True
或 False
。如果包括,则说明使用包含训练和验证集的路径(即在训练过程中使用验证集),否则只使用单独的训练集。
在opt.use_trainval=True
的情况下,将 opt.trial
加上 _trainval
的后缀,以表示此时的 trial 是使用训练和验证集的。
然后,通过检查 opt.model_path
、opt.tb_path
和 opt.data_root
是否已经定义,如果没有定义,则分别赋予默认值 './models_distilled'
、'./tensorboard'
和 ./data/{opt.dataset}
。如果已经定义,则根据已有的 opt.dataset
将其设置为合适的路径。
最后,设置 opt.data_aug
为 True
,指定对数据使用数据增强的方式。
def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):
def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):
"""load the teacher model"""
print('==> loading teacher model')
print(model_name)
model = create_model(model_name, n_cls, dataset)
model.load_state_dict(torch.load(model_path)['model'])
print('==> done')
return model
上述代码是一个加载教师模型的函数。函数接受以下参数:
- model_path:教师模型的路径。
- model_name:教师模型的名称。
- n_cls:分类任务的类别数。
- dataset:使用的数据集,默认为'miniImageNet'。
函数首先打印一条信息,表示正在加载教师模型,并输出模型名称。
然后,通过调用create_model()函数创建一个模型实例,传入模型名称和分类任务的类别数。create_model()函数根据模型名称和类别数选择合适的模型,并返回模型实例。
接下来,使用torch.load()函数加载保存在model_path中的教师模型的权重,并将权重加载到模型实例中的state_dict中。
最后,函数打印一条信息,表示模型加载完成,并返回加载后的模型实例。
需要注意的是,上述代码假设教师模型的权重是以字典形式保存在model_path中的,并且字典中的键为'model'。如果教师模型的权重保存方式不同,需要进行相应的修改。
def main():
调用opt,解析命令行参数并使用
def main():
best_acc = 0
opt = parse_option()
wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
wandb.config.update(opt)
wandb.save('*.py')
wandb.run.save()
best_acc
:用于记录最佳的准确率。opt = parse_option()
:调用parse_option()
函数,解析命令行参数并返回一个配置对象。前面只是定义命令行参数,到这里真的进行命令行参数解析了)wandb.init()
:初始化Weights & Biases,将模型路径和标签作为项目名称和标签进行设置。wandb.config.update(opt)
:更新Weights & Biases的配置,将配置对象传递给它。wandb.save('*.py')
:保存代码文件到Weights & Biases的运行目录。wandb.run.save()
:保存Weights & Biases的运行状态。
# dataloader
train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)
获取训练集、验证集和元测试集的数据加载器
get_dataloaders(opt)
:根据配置对象opt获取训练集、验证集和元测试集的数据加载器,并返回它们以及类别数量n_cls。
教师模型的加载
# model
model_t = []
if("," in opt.path_t):
for path in opt.path_t.split(","):
model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset))
else:
model_t.append(load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset))
- 创建一个空列表
model_t
用于存储教师模型。 - 如果
opt.path_t
中包含逗号,则按逗号分割路径并加载多个教师模型,然后将它们添加到model_t
列表中。 - 否则,只加载一个教师模型并将其添加到
model_t
列表中。 - load_teacher是之前定义的一个函数。
创建学生模型
model_s = copy.deepcopy(model_t[0])
- 创建一个学生模型,并使用
copy.deepcopy()
函数深度复制model_t
列表中的第一个教师模型。
交叉熵损失函数、知识蒸馏损失函
数、随机梯度下降优化器
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = DistillKL(opt.kd_T)
optimizer = optim.SGD(model_s.parameters(),
lr=opt.learning_rate,
momentum=opt.momentum,
weight_decay=opt.weight_decay)
- 创建交叉熵损失函数
criterion_cls
。 - 基于配置对象中的温度参数
kd_T
创建教师-学生之间的知识蒸馏损失函数criterion_div
和criterion_kd
。 - 创建随机梯度下降(SGD)优化器,并设置学习率、动量和权重衰减。
是否有可用的GPU设备及模型迁移到GPU
if torch.cuda.is_available():
for m in model_t:
m.cuda()
model_s.cuda()
criterion_cls = criterion_cls.cuda()
criterion_div = criterion_div.cuda()
criterion_kd = criterion_kd.cuda()
cudnn.benchmark = True
- 检查是否有可用的GPU设备。
- 如果是,将教师模型和学生模型移动到GPU上。
- 将损失函数也移动到GPU上。
- 启用cudnn加速。
准确率、标准差的定义,调整学习率
meta_test_acc = 0
meta_test_std = 0
- 初始化元测试集的准确率和标准差为0。
for epoch in range(1, opt.epochs + 1):
if opt.cosine:
scheduler.step()
else:
adjust_learning_rate(epoch, opt, optimizer)
print("==> training...")
time1 = time.time()
train_acc, train_loss = train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt)
time2 = time.time()
print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
- 根据配置对象中的
cosine
参数选择调整学习率的策略:如果为True,则使用余弦退火调整学习率;否则,调用adjust_learning_rate()
函数根据当前epoch来调整学习率。 - 打印训练开始的提示信息。
- 记录训练开始的时间。
- 调用
train()
函数进行模型训练,并获取训练集的准确率和损失。 - 记录训练结束的时间,并打印本轮训练的总时间。
验证集准确率,与评估结果
val_acc = 0
val_loss = 0
meta_val_acc = 0
meta_val_std = 0
- 初始化验证集准确率和损失、元验证集准确率和标准差为0。
start = time.time()
meta_test_acc, meta_test_std = meta_test(model_s, meta_testloader, use_logit=False)
test_time = time.time() - start
print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time))
- 记录当前时间。
- 调用
meta_test()
函数对学生模型在元测试集上进行评估,并获取元测试集的准确率和标准差。 - 计算评估过程的时间并打印结果。
保存模型
if epoch % opt.save_freq == 0 or epoch==opt.epochs:
print('==> Saving...')
state = {
'epoch': epoch,
'model': model_s.state_dict(),
}
save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth')
torch.save(state, save_file)
#wandb saving
torch.save(state, os.path.join(wandb.run.dir, "model.pth"))
- 如果当前epoch是保存频率的倍数,或者是最后一个epoch,则保存模型。
- 创建一个状态字典,包含当前epoch和学生模型的状态字典。
- 构建保存文件的路径,并保存模型。
- 还将模型保存到Weights & Biases的运行目录下。
wandb.log({'epoch': epoch,
'Train Acc': train_acc,
'Train Loss':train_loss,
'Val Acc': val_acc,
'Val Loss':val_loss,
'Meta Test Acc': meta_test_acc,
'Meta Test std': meta_test_std,
'Meta Val Acc': meta_val_acc,
'Meta Val std': meta_val_std
})
使用Weights & Biases的wandb.log()
函数记录各种指标和损失。
generate_final_report(model_s, opt, wandb)
调用generate_final_report()
函数生成最终报告。
output_log_file = os.path.join(wandb.run.dir, "output.log")
if os.path.isfile(output_log_file):
os.remove(output_log_file)
else:
print("Error: %s file not found" % output_log_file)
- 设置日志文件路径。
- 如果存在该文件,则删除日志文件。
- 否则,打印错误消息。
def train
def train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt):
"""One epoch training"""
- 定义了一个名为
train()
的函数,接受epoch数、训练集数据加载器、学生模型、教师模型、分类标准损失函数、蒸馏损失函数、知识蒸馏损失函数、优化器和配置对象作为参数。 - 函数的注释指出该函数用于进行一个epoch的训练。
将学生模型设置为训练模式,教师模型设置为评估模式
model_s.train()
for m in model_t:
m.eval()
- 将学生模型设置为训练模式,通过调用
model_s.train()
。 - 将所有的教师模型设置为评估模式,通过循环遍历所有的教师模型并调用
model.eval()
。
创建了用于记录训练过程中各种指标的AverageMeter对象
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
创建了用于记录训练过程中各种指标的AverageMeter对象:batch_time
(记录每批数据的耗时)、data_time
(记录数据加载的耗时)、losses
(记录训练损失)、top1
(记录top1准确率)、top5
(记录top5准确率)。
end = time.time()
with tqdm(train_loader, total=len(train_loader)) as pbar:
for idx, data in enumerate(pbar):
- 记录当前时间。
- 使用
tqdm
库创建了一个进度条,并迭代训练集数据加载器。
从数据中获取输入和目标
inputs, targets, _ = data
data_time.update(time.time() - end)
- 从数据中获取输入和目标。
- 更新数据加载的耗时,通过计算当前时间与上一步记录的时间差。
inputs = inputs.float()
if torch.cuda.is_available():
inputs = inputs.cuda()
targets = targets.cuda()
- 将输入和目标转换为浮点型数据。
- 如果有可用的CUDA设备,将输入和目标移动到CUDA设备上。
对输入进行旋转增强
batch_size = inputs.size()[0]
x = inputs
x_90 = x.transpose(2,3).flip(2)
x_180 = x.flip(2).flip(3)
x_270 = x.flip(2).transpose(2,3)
inputs_aug = torch.cat((x_90, x_180, x_270),0)
sampled_inputs = inputs_aug[torch.randperm(3*batch_size)[:batch_size]]
inputs_all = torch.cat((x, x_180, x_90, x_270),0)
- 获取批次大小。
- 对输入进行旋转增强:分别将输入旋转90度、180度和270度,并进行拼接,形成增强后的输入
inputs_aug
。 - 通过随机采样从增强输入中选择batch_size个样本,保存在
sampled_inputs
中。 - 将原始输入和所有增强后的输入进行拼接,形成总的输入
inputs_all
。
使用教师模型和学生模型分别对总的输入inputs_all
进行前向传播
with torch.no_grad():
(_,_,_,_, feat_t), (logit_t, rot_t) = model_t[0](inputs_all[:batch_size], rot=True)
(_,_,_,_, feat_s_all), (logit_s_all, rot_s_all) = model_s(inputs_all[:4*batch_size], rot=True)
loss_div = criterion_div(logit_s_all[:batch_size], logit_t[:batch_size])
d_90 = logit_s_all[batch_size:2*batch_size] - logit_s_all[:batch_size]
loss_a = torch.mean(torch.sqrt(torch.sum((d_90)**2, dim=1)))
# d_180 = logit_s_all[2*batch_size:3*batch_size] - logit_s_all[:batch_size]
# loss_a += torch.mean(torch.sqrt(torch.sum((d_180)**2, dim=1)))
# d_270 = logit_s_all[3*batch_size:4*batch_size] - logit_s_all[:batch_size]
# loss_a += torch.mean(torch.sqrt(torch.sum((d_270)**2, dim=1)))
if(torch.isnan(loss_a).any()):
break
else:
loss = loss_div + opt.gamma*loss_a / 3
- 使用教师模型和学生模型分别对总的输入
inputs_all
进行前向传播。 - 使用蒸馏损失函数计算知识蒸馏的损失
loss_div
,通过对比学生模型的预测结果和教师模型的预测结果。 - 计算一个角度(90度)上的旋转差异损失
loss_a
,即学生模型在输入上进行旋转后与原始输入之间的差异。 - 如果
loss_a
中存在NaN值,跳出循环。 - 否则,计算最终的损失
loss
,包括蒸馏损失和旋转差异损失。
计算准确率
acc1, acc5 = accuracy(logit_s_all[:batch_size], targets, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(acc1[0], inputs.size(0))
top5.update(acc5[0], inputs.size(0))
- 使用
accuracy()
函数计算学生模型的top1准确率和top5准确率。 - 更新损失、top1准确率和top5准确率的平均值。
optimizer.zero_grad()
loss.backward()
optimizer.step()
- 清空优化器的梯度。
- 反向传播计算梯度。
- 更新参数。
batch_time.update(time.time() - end)
end = time.time()
pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()),
"Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy(),2),
"Loss" :'{0:.2f}'.format(losses.avg,2),
})
- 更新每批数据的耗时。
- 更新当前时间。
- 使用
tqdm
库更新进度条的后缀,包括top1准确率、top5准确率和损失。
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg, losses.avg
- 打印最终的训练结果,包括top1准确率和top5准确率。
- 返回top1准确率和训练损失。
if __name__ == '__main__':
main()
程序入口