Project1 / SKD / train_distillation.py

目录

def get_freer_gpu():

class Wrapper(nn.Module):

def parse_option():

 parser

opt = parser.parse_args()

def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):

def main():

调用opt,解析命令行参数并使用

获取训练集、验证集和元测试集的数据加载器

教师模型的加载

创建学生模型

交叉熵损失函数、知识蒸馏损失函数、随机梯度下降优化器

是否有可用的GPU设备及模型迁移到GPU

准确率、标准差的定义,调整学习率

验证集准确率,与评估结果

保存模型

def train

将学生模型设置为训练模式,教师模型设置为评估模式

创建了用于记录训练过程中各种指标的AverageMeter对象

从数据中获取输入和目标

对输入进行旋转增强

使用教师模型和学生模型分别对总的输入inputs_all进行前向传播

计算准确率


def get_freer_gpu():

def get_freer_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
    memory_available = [int(3944), int(4095)]
    # memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return np.argmax(memory_available)

os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu())

这段代码的作用是获取当前可用的显存最大的 GPU 设备,并将该设备的索引号设置为 CUDA 可见设备的环境变量。

首先定义了一个名为 get_freer_gpu() 的函数。该函数通过运行系统命令 nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp,获取全部 GPU 设备可用显存大小,并将输出结果保存到名为 "tmp" 的文件中。接着,在内存可用数组中定义了两个整数,分别表示两个 GPU 设备的可用显存大小,这里提前定义的原因可能是部分 GPU 设备的可用显存大小是固定的,而无法通过命令获取。

然后执行 np.argmax(memory_available) 函数来获取当前可用显存最大的 GPU 设备的索引号,并返回该索引号,这个索引号表示在系统中GPU设备的编号。(这看起来不像设备索引号,像最大可用内存空间大小啊?)

最后,通过将该索引号转化为字符串并将其赋值给环境变量 "CUDA_VISIBLE_DEVICES",可以将该 GPU 设备设置为 CUDA 可见设备,让程序在运行时只使用该 GPU 设备进行计算,以充分利用 GPU 的计算加速能力。

class Wrapper(nn.Module):

class Wrapper(nn.Module):

    def __init__(self, model, args):
        super(Wrapper, self).__init__()
    
        self.model = model  # 看C++语法,这里已经忘了
        self.feat = torch.nn.Sequential(*list(self.model.children())[:-2])
        
        self.last = torch.nn.Linear(list(self.model.children())[-2].in_features, 64)       
        
    def forward(self, images):
        feat = self.feat(images)
        feat = feat.view(images.size(0), -1)
        out = self.last(feat)
        
        return feat, out
    

这个类名为 Wrapper,它是一个继承自 nn.Module 的子类,用于包装神经网络模型

Wrapper 类的构造函数 __init__() 中,接收两个参数 modelargsmodel 是一个已经定义好的神经网络模型,args 则可能是其他配置参数。

构造函数中的 super(Wrapper, self).__init__() 语句是调用父类 nn.Module 的构造函数进行初始化。# 这个没看懂,得复习一下C++语法

接下来,通过 self.model = model 将传入的模型赋值给 Wrapper 类的成员变量 self.model

然后,使用 torch.nn.Sequential() 构建了一个新的神经网络层序列 self.feat,该序列由 self.model 的所有子模型(即所有孩子节点)去掉最后两个子模型组成。可以理解为 self.featself.model 倒数第三个子模型的输出。

再之后,使用 torch.nn.Linear() 构造了一个线性层 self.last,该层的输入大小为 self.model 倒数第二子模型的输入特征大小,输出大小为 64。

forward() 方法定义了前向传播过程。给定输入 images,首先将其通过 self.feat 进行特征提取,得到特征 feat。然后将 feat 进行展平操作,正好可以将 images 的大小为 (batch_size, channels, height, width) 的张量转换成了 (batch_size, num_features)。接着,将展平后的特征 feat 作为输入传递给线性层 self.last,得到输出 out。最后,返回特征 feat 和输出 out

通过使用这个 Wrapper 类,可以方便地访问封装模型的不同层输出,尤其是在需要获取中间层特征时很有用。

def parse_option():

def parse_option():

    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # dataset and model
    parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool)
    parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool)
    parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
                                                                                'CIFAR-FS', 'FC100'])
    parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
    parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
    parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment')
    parser.add_argument('--transform', type=str, default='A', choices=transforms_list)

    # path to teacher model
    parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot')

    # distillation
    parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention'])
    parser.add_argument('--trial', type=str, default='1', help='trial id')

    parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
    parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
    parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses')

    # KL distillation
    parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation')
    # NCE distillation
    parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
    parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
    parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')

    # cosine annealing
    parser.add_argument('--cosine', action='store_true', help='using cosine annealing')

    # specify folder
    parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
    parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
    parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')

    # setting for meta-learning
    parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
                        help='Number of test runs')
    parser.add_argument('--n_ways', type=int, default=5, metavar='N',
                        help='Number of classes for doing each classification run')
    parser.add_argument('--n_shots', type=int, default=1, metavar='N',
                        help='Number of shots in test')
    parser.add_argument('--n_queries', type=int, default=15, metavar='N',
                        help='Number of query in test')
    parser.add_argument('--n_aug_support_samples', default=5, type=int,
                        help='The number of augmented samples for each meta test sample')
    parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
                        help='Size of test batch)')

    opt = parser.parse_args()

    if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        opt.transform = 'D'

    if 'trainval' in opt.path_t:
        opt.use_trainval = True
    else:
        opt.use_trainval = False

    if opt.use_trainval:
        opt.trial = opt.trial + '_trainval'

    # set the path according to the environment
    if not opt.model_path:
        opt.model_path = './models_distilled'
    if not opt.tb_path:
        opt.tb_path = './tensorboard'
    if not opt.data_root:
        opt.data_root = './data/{}'.format(opt.dataset)
    else:
        opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
    opt.data_aug = True
    
    tags = opt.tags.split(',')
    opt.tags = list([])
    for it in tags:
        opt.tags.append(it)
        
    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_name = 'S:{}_T:{}_{}_{}_r:{}_a:{}_b:{}_trans_{}'.format(opt.model_s, opt.model_t, opt.dataset,
                                                                      opt.distill, opt.gamma, opt.alpha, opt.beta,
                                                                      opt.transform)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    opt.model_name = '{}_{}'.format(opt.model_name, opt.trial)

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    if not os.path.isdir(opt.tb_folder):
        os.makedirs(opt.tb_folder)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)
    
    #extras
    opt.fresh_start = True
    
    
    return opt

上述代码是一个用解析命令行参数的函数。该函数使用argparse库来解析参数,并返回一个包含解析结果的对象opt

解析器定义了一系列的参数,如eval_freq、print_freq、tb_freq等。每个参数都有自己的类型、默认值和帮助信息

在函数内部,使用parser.parse_args()方法对命令行参数进行解析,并将解析结果保存在opt对象中。

随后,根据一些特殊的逻辑对opt的某些属性进行了一些调整和赋值。例如,如果数据集是'CIFAR-FS'或'FC100',则将opt.transform设置为'D';如果opt.path_t中包含'trainval',则将opt.use_trainval设置为True。

最后,根据一些规则设置了模型名称、存储路径等其他属性,并返回了opt对象。

需要注意的是,上述代码只是对解析器的定义和一些参数属性的初始化操作并没有执行实际的参数解析。实际的参数解析需要在调用parse_option()函数时进行。

 parser

def parse_option():

    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # dataset and model
    parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool)
    parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool)
    parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
                                                                                'CIFAR-FS', 'FC100'])
    parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
    parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
    parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment')
    parser.add_argument('--transform', type=str, default='A', choices=transforms_list)

    # path to teacher model
    parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot')

    # distillation
    parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention'])
    parser.add_argument('--trial', type=str, default='1', help='trial id')

    parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
    parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
    parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses')

    # KL distillation
    parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation')
    # NCE distillation
    parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
    parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
    parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')

    # cosine annealing
    parser.add_argument('--cosine', action='store_true', help='using cosine annealing')

    # specify folder
    parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
    parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
    parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')

    # setting for meta-learning
    parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
                        help='Number of test runs')
    parser.add_argument('--n_ways', type=int, default=5, metavar='N',
                        help='Number of classes for doing each classification run')
    parser.add_argument('--n_shots', type=int, default=1, metavar='N',
                        help='Number of shots in test')
    parser.add_argument('--n_queries', type=int, default=15, metavar='N',
                        help='Number of query in test')
    parser.add_argument('--n_aug_support_samples', default=5, type=int,
                        help='The number of augmented samples for each meta test sample')
    parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
                        help='Size of test batch)')

定义了许多参数。

这些参数可以从命令行接收,也可以在代码中进行硬编码或通过其他方式进行设置。根据 parser.add_argument() 的用法,这些参数可以在命令行中使用 --参数名 参数值 的形式进行设置,例如 --model_path save/

除了命令行参数外,还可以通过其他方式来设置这些参数的默认值。例如,在代码中直接为这些参数赋予默认值,如 default='save/'default='tb/'default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/'。这样在运行代码时,如果没有通过命令行传递这些参数,则会使用默认值

另外,你还可以通过其他方式传递这些参数,比如从配置文件中读取,或者通过函数调用时传递参数值。总之,argparse.ArgumentParser 可以用于解析命令行参数,但它并不限制参数只能从命令行接收,你可以根据自己的需求选择不同的方式来设置这些参数的值。

opt = parser.parse_args()

    opt = parser.parse_args()

    if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        opt.transform = 'D'

    if 'trainval' in opt.path_t:
        opt.use_trainval = True
    else:
        opt.use_trainval = False

    if opt.use_trainval:
        opt.trial = opt.trial + '_trainval'

    # set the path according to the environment
    if not opt.model_path:
        opt.model_path = './models_distilled'
    if not opt.tb_path:
        opt.tb_path = './tensorboard'
    if not opt.data_root:
        opt.data_root = './data/{}'.format(opt.dataset)
    else:
        opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
    opt.data_aug = True

这段代码中首先通过 parser.parse_args() 解析命令行参数,将所需的参数值存储在一个名为 opt 的对象中

接下来,通过判断 opt.dataset 是否为 'CIFAR-FS''FC100' 来设置 opt.transform'D'。这里的作用是区分数据集不同,因为对于 'CIFAR-FS''FC100' 数据集来说,数据增强的方式不同,'D' 方式适用于这两个数据集。

然后,通过判断 opt.path_t 中是否包含 'trainval' 来设置 opt.use_trainvalTrueFalse。如果包括,则说明使用包含训练和验证集的路径(即在训练过程中使用验证集),否则只使用单独的训练集。

opt.use_trainval=True 的情况下,将 opt.trial 加上 _trainval 的后缀,以表示此时的 trial 是使用训练和验证集的。

然后,通过检查 opt.model_pathopt.tb_pathopt.data_root 是否已经定义,如果没有定义,则分别赋予默认值 './models_distilled''./tensorboard'./data/{opt.dataset}。如果已经定义,则根据已有的 opt.dataset 将其设置为合适的路径。

最后,设置 opt.data_augTrue,指定对数据使用数据增强的方式。

def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):

def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):
    """load the teacher model"""
    print('==> loading teacher model')
    print(model_name)
    model = create_model(model_name, n_cls, dataset)
    model.load_state_dict(torch.load(model_path)['model'])
    print('==> done')
    return model

上述代码是一个加载教师模型的函数。函数接受以下参数:

  • model_path:教师模型的路径。
  • model_name:教师模型的名称。
  • n_cls:分类任务的类别数。
  • dataset:使用的数据集,默认为'miniImageNet'。

函数首先打印一条信息,表示正在加载教师模型,并输出模型名称。

然后,通过调用create_model()函数创建一个模型实例,传入模型名称和分类任务的类别数。create_model()函数根据模型名称和类别数选择合适的模型,并返回模型实例

接下来,使用torch.load()函数加载保存在model_path中的教师模型的权重,并将权重加载到模型实例中的state_dict中。

最后,函数打印一条信息,表示模型加载完成,并返回加载后的模型实例。

需要注意的是,上述代码假设教师模型的权重是以字典形式保存在model_path中的,并且字典中的键为'model'。如果教师模型的权重保存方式不同,需要进行相应的修改。

def main():

调用opt,解析命令行参数并使用

def main():
    best_acc = 0

    opt = parse_option()
    wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
    wandb.config.update(opt)
    wandb.save('*.py')
    wandb.run.save()
  • best_acc:用于记录最佳的准确率。
  • opt = parse_option():调用parse_option()函数,解析命令行参数并返回一个配置对象。前面只是定义命令行参数,到这里真的进行命令行参数解析了)
  • wandb.init():初始化Weights & Biases,将模型路径和标签作为项目名称和标签进行设置。
  • wandb.config.update(opt):更新Weights & Biases的配置,将配置对象传递给它。
  • wandb.save('*.py'):保存代码文件到Weights & Biases的运行目录。
  • wandb.run.save():保存Weights & Biases的运行状态。
    # dataloader
    train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)

获取训练集、验证集和元测试集的数据加载器

  • get_dataloaders(opt):根据配置对象opt获取训练集、验证集和元测试集的数据加载器,并返回它们以及类别数量n_cls。

教师模型的加载

    # model
    model_t = []
    if("," in opt.path_t):
        for path in opt.path_t.split(","):
            model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset))
    else:
        model_t.append(load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset))
  • 创建一个空列表model_t用于存储教师模型。
  • 如果opt.path_t中包含逗号,则按逗号分割路径并加载多个教师模型,然后将它们添加到model_t列表中。
  • 否则,只加载一个教师模型并将其添加到model_t列表中。
  • load_teacher是之前定义的一个函数。

创建学生模型

    model_s = copy.deepcopy(model_t[0])
  • 创建一个学生模型,并使用copy.deepcopy()函数深度复制model_t列表中的第一个教师模型。

交叉熵损失函数、知识蒸馏损失函
数、随机梯度下降优化器

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    criterion_kd = DistillKL(opt.kd_T)

    optimizer = optim.SGD(model_s.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
  • 创建交叉熵损失函数criterion_cls
  • 基于配置对象中的温度参数kd_T创建教师-学生之间的知识蒸馏损失函数criterion_divcriterion_kd
  • 创建随机梯度下降(SGD)优化器,并设置学习率、动量和权重衰减。

是否有可用的GPU设备及模型迁移到GPU

    if torch.cuda.is_available():
        for m in model_t: 
            m.cuda()
        model_s.cuda()
        criterion_cls = criterion_cls.cuda()
        criterion_div = criterion_div.cuda()
        criterion_kd = criterion_kd.cuda()
        cudnn.benchmark = True
  • 检查是否有可用的GPU设备。
  • 如果是,将教师模型和学生模型移动到GPU上。
  • 将损失函数也移动到GPU上。
  • 启用cudnn加速。

准确率、标准差的定义,调整学习率

    meta_test_acc = 0 
    meta_test_std = 0
  • 初始化元测试集的准确率和标准差为0。
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
  • 根据配置对象中的cosine参数选择调整学习率的策略:如果为True,则使用余弦退火调整学习率;否则,调用adjust_learning_rate()函数根据当前epoch来调整学习率。
  • 打印训练开始的提示信息。
  • 记录训练开始的时间。
  • 调用train()函数进行模型训练,并获取训练集的准确率和损失。
  • 记录训练结束的时间,并打印本轮训练的总时间。

验证集准确率,与评估结果

        val_acc = 0
        val_loss = 0
        meta_val_acc = 0
        meta_val_std = 0
  • 初始化验证集准确率和损失、元验证集准确率和标准差为0。
        start = time.time()
        meta_test_acc, meta_test_std = meta_test(model_s, meta_testloader, use_logit=False)
        test_time = time.time() - start
        print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time))
  • 记录当前时间。
  • 调用meta_test()函数对学生模型在元测试集上进行评估,并获取元测试集的准确率和标准差。
  • 计算评估过程的时间并打印结果。

保存模型

        if epoch % opt.save_freq == 0 or epoch==opt.epochs:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }            
            save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth')
            torch.save(state, save_file)
            
            #wandb saving
            torch.save(state, os.path.join(wandb.run.dir, "model.pth"))
  • 如果当前epoch是保存频率的倍数,或者是最后一个epoch,则保存模型。
  • 创建一个状态字典,包含当前epoch和学生模型的状态字典。
  • 构建保存文件的路径,并保存模型。
  • 还将模型保存到Weights & Biases的运行目录下。
        wandb.log({'epoch': epoch, 
                   'Train Acc': train_acc,
                   'Train Loss':train_loss,
                   'Val Acc': val_acc,
                   'Val Loss':val_loss,
                   'Meta Test Acc': meta_test_acc,
                   'Meta Test std': meta_test_std,
                   'Meta Val Acc': meta_val_acc,
                   'Meta Val std': meta_val_std
                  })     

使用Weights & Biases的wandb.log()函数记录各种指标和损失。

    generate_final_report(model_s, opt, wandb)

调用generate_final_report()函数生成最终报告。

    output_log_file = os.path.join(wandb.run.dir, "output.log")
    if os.path.isfile(output_log_file):
        os.remove(output_log_file)
    else:
        print("Error: %s file not found" % output_log_file)
  • 设置日志文件路径。
  • 如果存在该文件,则删除日志文件。
  • 否则,打印错误消息。

def train

def train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt):
    """One epoch training"""
  • 定义了一个名为train()的函数,接受epoch数、训练集数据加载器、学生模型、教师模型、分类标准损失函数、蒸馏损失函数、知识蒸馏损失函数、优化器和配置对象作为参数。
  • 函数的注释指出该函数用于进行一个epoch的训练。

将学生模型设置为训练模式,教师模型设置为评估模式

    model_s.train()
    for m in model_t:
        m.eval()
  • 将学生模型设置为训练模式,通过调用model_s.train()
  • 将所有的教师模型设置为评估模式,通过循环遍历所有的教师模型并调用model.eval()

创建了用于记录训练过程中各种指标的AverageMeter对象

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

创建了用于记录训练过程中各种指标的AverageMeter对象:batch_time(记录每批数据的耗时)、data_time(记录数据加载的耗时)、losses(记录训练损失)、top1(记录top1准确率)、top5(记录top5准确率)。

    end = time.time()
    
    with tqdm(train_loader, total=len(train_loader)) as pbar:
        for idx, data in enumerate(pbar):
  • 记录当前时间。
  • 使用tqdm库创建了一个进度条,并迭代训练集数据加载器。

从数据中获取输入和目标

            inputs, targets, _ = data
            data_time.update(time.time() - end)
  • 从数据中获取输入和目标。
  • 更新数据加载的耗时,通过计算当前时间与上一步记录的时间差。
            inputs = inputs.float()
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()
  • 将输入和目标转换为浮点型数据。
  • 如果有可用的CUDA设备,将输入和目标移动到CUDA设备上。

对输入进行旋转增强

            batch_size = inputs.size()[0]
            x = inputs
            
            x_90 = x.transpose(2,3).flip(2)
            x_180 = x.flip(2).flip(3)
            x_270 = x.flip(2).transpose(2,3)
            inputs_aug = torch.cat((x_90, x_180, x_270),0)
            
            
            sampled_inputs = inputs_aug[torch.randperm(3*batch_size)[:batch_size]]
            inputs_all = torch.cat((x, x_180, x_90, x_270),0)
  • 获取批次大小。
  • 对输入进行旋转增强:分别将输入旋转90度、180度和270度,并进行拼接,形成增强后的输入inputs_aug
  • 通过随机采样从增强输入中选择batch_size个样本,保存在sampled_inputs中。
  • 将原始输入和所有增强后的输入进行拼接,形成总的输入inputs_all

使用教师模型和学生模型分别对总的输入inputs_all进行前向传播

            with torch.no_grad():
                (_,_,_,_, feat_t), (logit_t, rot_t) = model_t[0](inputs_all[:batch_size], rot=True)

            (_,_,_,_, feat_s_all), (logit_s_all, rot_s_all)  = model_s(inputs_all[:4*batch_size], rot=True)
            
            loss_div = criterion_div(logit_s_all[:batch_size], logit_t[:batch_size])

            d_90 = logit_s_all[batch_size:2*batch_size] - logit_s_all[:batch_size]
            loss_a = torch.mean(torch.sqrt(torch.sum((d_90)**2, dim=1)))
#             d_180 = logit_s_all[2*batch_size:3*batch_size] - logit_s_all[:batch_size]
#             loss_a += torch.mean(torch.sqrt(torch.sum((d_180)**2, dim=1)))
#             d_270 = logit_s_all[3*batch_size:4*batch_size] - logit_s_all[:batch_size]
#             loss_a += torch.mean(torch.sqrt(torch.sum((d_270)**2, dim=1)))


            if(torch.isnan(loss_a).any()):
                break
            else:
                loss = loss_div + opt.gamma*loss_a / 3
  • 使用教师模型和学生模型分别对总的输入inputs_all进行前向传播。
  • 使用蒸馏损失函数计算知识蒸馏的损失loss_div,通过对比学生模型的预测结果和教师模型的预测结果。
  • 计算一个角度(90度)上的旋转差异损失loss_a,即学生模型在输入上进行旋转后与原始输入之间的差异。
  • 如果loss_a中存在NaN值,跳出循环。
  • 否则,计算最终的损失loss,包括蒸馏损失和旋转差异损失。

计算准确率

            acc1, acc5 = accuracy(logit_s_all[:batch_size], targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))
  • 使用accuracy()函数计算学生模型的top1准确率和top5准确率。
  • 更新损失、top1准确率和top5准确率的平均值。
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
  • 清空优化器的梯度。
  • 反向传播计算梯度。
  • 更新参数。
            batch_time.update(time.time() - end)
            end = time.time()
            
            pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 
                              "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy(),2), 
                              "Loss" :'{0:.2f}'.format(losses.avg,2), 
                             })
  • 更新每批数据的耗时。
  • 更新当前时间。
  • 使用tqdm库更新进度条的后缀,包括top1准确率、top5准确率和损失。
    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg
  • 打印最终的训练结果,包括top1准确率和top5准确率。
  • 返回top1准确率和训练损失。
    
if __name__ == '__main__':
    main()

程序入口

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值