Learning to Continually Learn论文笔记+代码解读

论文地址点这里

一. 介绍

在终身学习的场景下,灾难性遗忘成为模型参数遇到的最大的问题。针对这个问题,其中一个解决方法是使用重放方法。通过保存以前的经验,将它与新的数据混合来进行交叉来训练遗忘。然后保存这些信息需要的计算和存储消耗较大。另一种是采取冻结权重,直观思想是冻结那些容易导致结果变化较大的权重。EWC作为一个最基本的基于冻结权重的方式,根据计算fisher信息来进行对权重的衡量。其他类似的方法也是试图通过特定任务的突触重要性,通过对任务的权重变化使用L2正则化。还有一种方式是直接激励创建最稀疏的表示,目标是最大限度地减少激活之间的干扰。这种稀疏表示间接地影响了反向传播中更新哪些参数。
而作者提出的方案是直接优化不遗忘学习,没有去设计一个解决灾难性遗忘的方案,也没有添加减轻灾难性遗忘的辅助损失。通过元学习方式让网络能够不断学习。本文引入了一个新的网络架构——该网络根据输入,通过另一个网络的激活来控制,从而导致第二个网络的选择性激活(这里也就是说有一个网络根据输入会输出一项0和正数组成的“门”,这个门作为一种激活函数处理我们分类网络的结果,从而使得部分权重被筛选)。这个随机门控取决于任务而不是数据本身。在训练中,我们不需要告诉网络他正在接受哪个任务的训练或测试,通过远远学习激活门。具体的例子如下
在这里插入图片描述
蓝色的网络为神经调节网络,通过对预测网络的前向传播中的激活产生一个以单元为单位的门限,使得选择性激活成为可能,并通过影响预测网络的梯度更新间接地使选择性可塑性成为可能。

二. 问题定义

首先针对元学习。元学习包括inner loop和outer loop。我们的目标是进行一些outer loop来使得inner loop的学习更好。通过outer loop来优化改善inner loop的学习能力称为元训练。而当元训练训练结束后,测试inner loop的过程称为元测试。在元训练过程中,对于其中的每一个inner loop都会学习到一些信息,这一过程称为元训练的训练过程(例如,在元训练中,inner loop基于MNIST数据集而执行几次梯度更新这个就称为元训练的训练过程)。在每一个inner loop的迭代后,每一个inner loop的训练者都必须进行评估,这一过程称为元训练的测试部分。而元训练的测试损失作为元训练损失进行最小化。在完成元训练后,我们转到元测试,需要评估元学习的学习者的表现。同样我们也分为训练和测试阶段。(这里不懂得可以看看我之前讲到的元学习部分,通俗意义来说,就是元学习是学习一种网络架构或者参数初始化信息。针对不同任务,我们每个任务的参数不同,但我们需要相仿设法去学习到一个统一的信息,而这个信息将针对每个任务形成独立的结构。因此在元训练阶段,我们首先会用元学习的参数去更新对应任务的结构,然后每个网络独自训练自己的任务参数,这里就是inner loop,inner loop反映出这个模型接收到的结构训练的好坏,之后我们再进行outer loop的更新)。
而如果我们的类的数量特别的多,普通的元学习框架将无法很好的运行。而作者提出的OML方式通过评估是否学习到了新知识以及旧知识是否丢失这两个指标来更新。在每个新遇到的类之后,元损失被计算为新学习的类上的错误和所有元训练类中随机样本上的错误。

三. 基于神经调节的元学习算法(ANML)

为了解决灾难性遗忘,作者提出了一个神经调节网络(也就是上面所述的激活门)。这种设置允许正常网络中的不同子网用于不同类型的任务,并从中学习。

3.1 ANML的架构

ANML网络包括两个并行的网络:神经调节网络和预测网络。这两个网络的权值都在outer loop进行更新,但神经调节网络不会再inner loop进行更新。每个网络都有三个卷积层(每个层后面跟着一个规范层和一个全连接层,预测网络的最后一层和神经调节网络的最后一层输入大小相同。在前向传播时,神经调节输出被用来屏蔽预测网络的潜在表示(两者相乘))。所有的激活函数都为relu,除了门控器(也就是神经调节网络的最后一层)还经过一个sigmoid函数将值限定为0到1,这意味着在工作时,这个门控智能抑制网络的激活。

3.2 元训练过程

ANML算法由一个内循环嵌套在一个优化的外循环组成。在每一个内循环(inner loop)中,首先从初始化参数 θ P \theta^P θP进行拷贝,然后针对一个单独的Omniglot元训练类 T n \mathcal{T_n} Tn训练20的SGD的迭代: θ 0 P , θ 1 P , θ 2 P , . . . , θ 20 P \theta^P_0,\theta^P_1,\theta^P_2,...,\theta^P_{20} θ0P,θ1P,θ2P,...,θ20P
在这20轮迭代中,输入参数经过预测网络后与我们的神经调节网络参数 θ N M \theta^{NM} θNM形成的门控激活函数进行相乘来进行选择性激活。
在每次反向传播的过程中,预测网络的门设置自然会减少流向其权重子集的梯度,从而修改SGD的更新,从而产生选择性的可塑性。(此时 θ N M \theta^{NM} θNM不进行更新)。
在对这个元训练类进行20次连续的梗系之后,使用训练的结果 θ 20 P \theta^P_{20} θ20P以及元训练类集合(记忆集)中的随机样本的64个类实例进行预测,来计算元损失。这种元损失函数被称为在线感知元学习(OML)。之后我们通过SGD更新的20个步骤反向传播这个元损失,更新我们 θ p \theta^p θp θ N M \theta^{NM} θNM(外循环)。外循环是通过Adam进行完成的,而所有的内循环更新通过SGD固定的学习率完成的。
具体算法如图所示:
在这里插入图片描述

3.2 元测试阶段

在训练完之后我们对于元测试中的训练集进行微调,采用同样的方式首先对每个 θ P \theta^P θP进行更新k次,之后再进行元测试-测试集的评估。
算法如图:
在这里插入图片描述

四. 代码解读

官方代码点这里

4.1 元训练

4.1.1 ANML的网络结构

元训练的代码对应为mrcl_classfication.py,我们根据第三节提出的,先针对大家最疑惑的神经元调节网络和预测网络的结构开始说明。

nm_channels = 112
channels = 256
size_of_representation = 2304
size_of_interpreter = 1008
# =============== Separate network neuromodulation =======================

('conv1_nm', [nm_channels, 3, 3, 3, 1, 0]),
('bn1_nm', [nm_channels]),
('conv2_nm', [nm_channels, nm_channels, 3, 3, 1, 0]),
('bn2_nm', [nm_channels]),
('conv3_nm', [nm_channels, nm_channels, 3, 3, 1, 0]),
('bn3_nm', [nm_channels]),

('nm_to_fc', [size_of_representation, size_of_interpreter]),

# =============== Prediction network ===============================

('conv1', [channels, 3, 3, 3, 1, 0]),
('bn1', [channels]),
('conv2', [channels, channels, 3, 3, 1, 0]),
('bn2', [channels]),
('conv3', [channels, channels, 3, 3, 1, 0]),
('bn3', [channels]),
('fc', [1000, size_of_representation]),

可以看到神经元调节网络和预测网络都为三个卷积层+1个全连接。对于神经调节网络由输入x:[batch_size,3,28,28]变为输出mask:[batch_size,2304]。对于预测网络则是,输入输入x:[batch_size,3,28,28]变为输出前最后一层h:[batch_size,2304],这时候h*mask进行选择性神经元的调节,也就是降低权重,最后再输出y:[batch_size,1000]。(注意:这里的batch_size指的是对应的一个集有几张图片,例如batch_size=10,也就是说当前选取的集存在10张图片)
对应着来看一下网络是怎么工作的

def forward(self, x, vars=None, bn_training=True, feature=False):
    """
    This function can be called by finetunning, however, in finetunning, we dont wish to update
    running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
    Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
    but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
    :param x: [b, 1, 28, 28]
    :param vars:
    :param bn_training: set False to not update
    :return: x, loss, likelihood, kld
    """
    
    cat_var = False
    cat_list = []

    if vars is None:
        vars = self.vars
    idx = 0
    bn_idx = 0

    if self.Neuromodulation:

        # =========== NEUROMODULATORY NETWORK ===========

        #'conv1_nm'
        #'bn1_nm'
        #'conv2_nm'
        #'bn2_nm'
        #'conv3_nm'
        #'bn3_nm'

      
        # Query the neuromodulatory network:
        # 对每张图片进行分析,这里单独拆开每一张图片是因为support集和query集的batch_size不一样大
        for i in range(x.size(0)):
        	# 将数据复制,分别传入神经元调节网络和预测网络
            data = x[i].view(1,3,28,28)
            nm_data = x[i].view(1,3,28,28)
            # 神经元调节网络的第一个卷积层
            w,b = vars[0], vars[1]
            nm_data = conv2d(nm_data, w, b)
            w,b = vars[2], vars[3]
            running_mean, running_var = self.vars_bn[0], self.vars_bn[1]
            nm_data = F.batch_norm(nm_data, running_mean, running_var, weight=w, bias=b, training=True)

            nm_data = F.relu(nm_data)
            nm_data = maxpool(nm_data, kernel_size=2, stride=2)
			# 神经元调节网络的第二个卷积层
            w,b = vars[4], vars[5]
            nm_data = conv2d(nm_data, w, b)
            w,b = vars[6], vars[7]
            running_mean, running_var = self.vars_bn[2], self.vars_bn[3]
            nm_data = F.batch_norm(nm_data, running_mean, running_var, weight=w, bias=b, training=True)

            nm_data = F.relu(nm_data)
            nm_data = maxpool(nm_data, kernel_size=2, stride=2)
			# 神经元调节网络的第三个卷积层
            w,b = vars[8], vars[9]
            nm_data = conv2d(nm_data, w, b)
            w,b = vars[10], vars[11]
            running_mean, running_var = self.vars_bn[4], self.vars_bn[5]
            nm_data = F.batch_norm(nm_data, running_mean, running_var, weight=w, bias=b, training=True)
            nm_data = F.relu(nm_data)

            nm_data = nm_data.view(nm_data.size(0), 1008)

            # 神经元调节网络的最后一层也就是输出门控网络

            w,b = vars[12], vars[13]
            # num_data经过rulu变为包含0和正数的矩阵,再经过sigmoid变为0~1之间的矩阵从而可以抑制部分神经元
            fc_mask = F.sigmoid(F.linear(nm_data, w, b)).view(nm_data.size(0), 2304)


            # =========== PREDICTION NETWORK ===========
			#预测部分和神经元调节部分同理,经过三个卷积层
            #'conv1'
            #'bn1'
            #'conv2'
            #'bn2'
            #'conv3'
            #'bn3'
            #'fc'

            w,b = vars[14], vars[15]
        
            data = conv2d(data, w, b)

            w,b = vars[16], vars[17]
            running_mean, running_var = self.vars_bn[6], self.vars_bn[7]
            data = F.batch_norm(data, running_mean, running_var, weight=w, bias=b, training=True)
            data = F.relu(data)
            data = maxpool(data, kernel_size=2, stride=2)

            w,b = vars[18], vars[19]
        
            data = conv2d(data, w, b, stride=1)
            w,b = vars[20], vars[21]
            running_mean, running_var = self.vars_bn[8], self.vars_bn[9]
            data = F.batch_norm(data, running_mean, running_var, weight=w, bias=b, training=True)
            data = F.relu(data)
            data = maxpool(data, kernel_size=2, stride=2)
            
            w,b = vars[22], vars[23]

            data = conv2d(data, w, b, stride=1)
            w,b, = vars[24], vars[25]
            running_mean, running_var = self.vars_bn[10], self.vars_bn[11]
            data = F.batch_norm(data, running_mean, running_var, weight=w, bias=b, training=True)
            data = F.relu(data)
            #data = maxpool(data, kernel_size=2, stride=2)

            data = data.view(data.size(0), 2304) #nothing-max-max
            # 将我们的fc_mask*data进行神经元抑制
            data = data*fc_mask
            w,b = vars[26], vars[27]
            data = F.linear(data, w, b)

            try:
                prediction = torch.cat([prediction, data], dim=0)
            except:
                prediction = data
	return(prediction)

4.1.2 元训练的数据构成

在开始进行训练前,首先先对数据构成进行说明。根据元学习,训练部分的数据包含两部分,一个为support set进行学习,一个为query set进行测试,计算meta-loss。再ANML下,support set中作者只涉及到一个新类,这个新类包含20张图片,所以一个support set中x=[20,3,28,28],y=[20,1],其中每一个y的值都相同。而query set对应作者设定的remmeber set(记忆集),是随机选取了84张来自不同类的图片,也就是x=[84,3,28,28],y=[84],其中每一个y都不相同。这里我放一下运行过程的截图:
在这里插入图片描述
有了数据集,我们可以开始进行训练,下面为训练的框架

for step in range(args.steps):
	# 随机选取一个新的类
    t1 = np.random.choice(args.classes, args.tasks, replace=False)
	
    d_traj_iterators = []
    for t in t1:
        d_traj_iterators.append(sampler.sample_task([t]))

    d_rand_iterator = sampler.get_complete_iterator()
	# 构建support set和query set
    x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(d_traj_iterators, d_rand_iterator,
                                                           steps=args.update_step, reset=not args.no_reset)
    if torch.cuda.is_available():
        x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()
	# 进行元训练
    accs, loss = maml(x_spt, y_spt, x_qry, y_qry)#, args.tasks)

    # Evaluation during training for sanity checks
    if step % 40 == 0:
        #writer.add_scalar('/metatrain/train/accuracy', accs, step)
        logger.info('step: %d \t training acc %s', step, str(accs))
    if step % 100 == 0 or step == 19999:
        torch.save(maml.net, args.model_name)
    if step % 2000 == 0 and step != 0:
        utils.log_accuracy(maml, my_experiment, iterator_test, device, writer, step)
        utils.log_accuracy(maml, my_experiment, iterator_train, device, writer, step)

4.1.3 元训练过程

有了网络,有了数据集,接下来我们对4.1.2中的训练过程maml进行分析。
根据算法流程,首先我们要进行inner loop的更新,简单回顾一下,inner loop包括三部分,首先是对初始化参数 θ P \theta^P θP进行复制,使得 θ 0 P = θ P \theta^P_0 = \theta^P θ0P=θP。接下来根据suppoert set数据计算损失,最后更新梯度计算出 θ 1 P , θ 2 P . . . \theta^P_1,\theta^P_2... θ1P,θ2P...。对应代码如下:

## 迭代更新fast_weight
fast_weights = self.inner_update(x_traj[0], None, y_traj[0], False)        
for k in range(1, self.update_step):
    # Doing inner updates using fast weights
    fast_weights = self.inner_update(x_traj[k], fast_weights, y_traj[k], False)

我们躯体看一看inner_update

def inner_update(self, x, fast_weights, y, bn_training):
	# 根据网络进行预测
    logits = self.net(x, fast_weights, bn_training=bn_training)
    # 计算损失
    loss = F.cross_entropy(logits, y)
	# 第一次直接赋值
    if fast_weights is None:
        fast_weights = self.net.parameters()
	# 计算梯度
    grad = torch.autograd.grad(loss, fast_weights, allow_unused=False)
	# 梯度更新,\theta = \theta - lr*grad
    fast_weights = list(
        map(lambda p: p[1] - self.update_lr * p[0] if p[1].learn else p[1], zip(grad, fast_weights)))
	# 注意这里,再inner loop中我们的神经调节网络是不进行更新的,因此要保证只有预测网络可以进行学习更新
    for params_old, params_new in zip(self.net.parameters(), fast_weights):
        params_new.learn = params_old.learn

    return fast_weights

inner update算完之后,我们进行outer loop的更新。outer loop相对来说就更简单了,首先根据inner loop中更新出来的fast_weight在query set上计算梯度,再对我们的 θ P 和 θ N M \theta^P和\theta^{NM} θPθNM进行梯度更新。对应代码如下:

meta_loss, logits = self.meta_loss(x_rand[0], fast_weights, y_rand[0], False)

# 计算准确率
with torch.no_grad():
    pred_q = F.softmax(logits, dim=1).argmax(dim=1)
    classification_accuracy = torch.eq(pred_q, y_rand[0]).sum().item()  # convert to numpy

# Taking the meta gradient step

self.net.zero_grad()
meta_loss.backward()

self.optimizer.step()

classification_accuracy /= len(x_rand[0])

self.meta_iteration += 1

meta_loss的计算如下:

def meta_loss(self, x, fast_weights, y, bn_training):

    logits = self.net(x, fast_weights, bn_training=bn_training)
    loss_q = F.cross_entropy(logits, y)
    return loss_q, logits

4.2 元测试阶段

元测试阶段相对来说比较简单,我们从训练中获得的 θ P \theta^P θP θ N M \theta^{NM} θNM进行更新,之后预测即可。
(这里要吐槽一下,作者写的这个元测试真的很乱,很多地方就直接复制粘贴,代码中出现大量冗余)
首先说一下数据划分,数据集中作者设定了一次取多少个类,一次进行微调,分别为(10,50,75,100,200,300,400,500,600)。意思是按照顺序先从数据集随机抽取10个类进行微调,微调后计算准确率,之后再取50个类,以此类推。值得注意的是,我们的场景时终身学习,各种任务是按照顺序到来的。(接下来我都以10个类作为讲解),这里的代码如下:

keep = np.random.choice(list(range(650)), tot_class, replace=False)
# 选10*15张图片
dataset = utils.remove_classes_omni(
    df.DatasetFactory.get_dataset("omniglot", train=True, background=False, path=args.dataset_path), keep)
## support set
iterator_sorted = torch.utils.data.DataLoader(
    utils.iterator_sorter_omni(dataset, False, classes=total_clases),
    batch_size=1,
    shuffle=args.iid, num_workers=2)
dataset = utils.remove_classes_omni(
    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False, path=args.dataset_path),
    keep)
## query set 注意这里的suffle为False
iterator = torch.utils.data.DataLoader(dataset, batch_size=1,
                                       shuffle=False, num_workers=1)

当我们随机抽取了10个类,每个类有15张图片,所以对于测试中进行微调的suppoert为:[150,3,28,28]。训练代码如下:

for _ in range(0, args.epoch):
    for img, y in iterator_sorted:
        img = img.to(device)
        y = y.long()
        y = y.to(device)

        pred = maml(img)
        opt.zero_grad()
        loss = F.cross_entropy(pred, y)
        loss.backward()
        opt.step()

预测代码类似就不贴了。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值