CVPR 2021 | 强化学习太脆弱?VAI: 用注意力和不变性来让像素输入的强化学习更加稳定...

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

Unsupervised Visual Attention and Invariance for Reinforcement Learning

作者:Xudong Wang* (UC Berkeley / ICSI), Long Lian* (UC Berkeley / ICSI), Stella X. Yu (UC Berkeley / ICSI) (其中*代表共同一作)

论文链接:

https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Unsupervised_Visual_Attention_and_Invariance_for_Reinforcement_Learning_CVPR_2021_paper.pdf

代码链接:

https://github.com/TonyLianLong/VAI-ReinforcementLearning

Poster:

https://drive.google.com/file/d/1jnNUhp_S-EOcmPbCZU9WnaVKXpN4-TK-/view

强化学习 (Reinforcement Learning; 简称 RL) 这个概念相信大家一定都不陌生。通过给出一个Reward数值而非直接让模型来模仿决策,使用强化学习的模型能学出各种惊人的决策方式。DeepMind的AlphaGo就运用了强化学习来让模型可以战胜人类选手,腾讯训练的的王者荣耀觉悟AI也是一个基于强化学习训练的模型,它一天的训练就抵得上人类选手400年的训练水平。

AlphaGo计算棋子移动的相关概率 (Image Credit: https://deepmind.com/research/case-studies/alphago-the-story-so-far)

视觉强化学习

许多经典的强化学习模型使用人类手动编码的信息作为模型的输入,例如对于围棋,我们可以定义两方剩余棋子的数量差作为一个输入,这样我们就能间接捕捉双方目前的获胜概率,以此来进行下一步决策。然而并不是所有的环境都适合这种情况,就例如机器人控制领域,我们很可能只有一个摄像头来作为输入,在这种情况下从摄像头的像素级输入里提取有用的信息就变得有些困难了。这些年来,随着神经网络的发展,很多强化学习模型抛弃了原来的手工选定feature输入的形式,转而使用end-to-end的方式直接以图像作为输入来训练模型。DQN就是这样的一个例子,它直接使用图像输入在Atari游戏机里面使用RL打游戏,这种直接以视觉作为输入的强化学习就叫做视觉强化学习(Vision-based RL),与之相对的使用低维度的特征来训练的模型训练方法叫做State-based RL。

基于DQN打Atari游戏Pong (Image Credit: https://towardsdatascience.com/welcome-to-deep-reinforcement-learning-part-1-dqn-c3cab4d41b6b)

视觉强化学习的一些问题

尽管神经网络的出现让基于视觉的强化学习获得了一定的效果提升,可是在实际生活中使用视觉强化学习算法依然较为困难。为什么呢?除了强化学习本身的训练需要大量的样本来收敛以外,另一个严峻的问题是强化学习模型在有训练过程中没见到过的干扰的环境下极为不稳定。即便是我们使用已经在稳定性方面做过很多提升的SAC算法,决策的准确性依然容易被一些测试环境中的轻微干扰严重影响。我们这里做了一个实验:在经典的DeepMind Control的几个测试环境中,即使是仅针对背景颜色进行变换(前景保持不变),也会造成整个模型无法完成所需要的任务。

我们的目标:让基于视觉的强化学习agent可以在不同的测试环境都拥有很强的泛化能力。

目前的baseline在测试环境不同或者遇到微弱干扰的情况下极易崩溃,哪怕这些干扰与实际任务并非强相关。

目前也有很多的研究人员在尝试解决上述问题。现有的解决方案大多是想办法在训练过程中或者部署过程中针对测试环境变化校正模型,从而学出一个通用的RL模型(Universal RL Model)来保持鲁棒性。就比如我们想让模型在所有的背景颜色下面都表现良好,就在训练的时候让模型遍历所有的背景颜色。这样的方法(Domain Randomization)有一定的作用,可是也有些问题:首先,如果我们想让模型在多种环境下具有鲁棒性,例如背景颜色+阴影+背景中与前景相似的物体,那么我们就需要在多种环境中学习。然而,这一方案将导致环境场景的排列组合数量指数级增长(10种是或者否的情况则有1024种组合情况需要学习);其次,在实际的情况下使用这种方法会产生很严重的训练稳定性问题。因为我们给神经网络的训练环境加了扰动,而神经网络要学出有意义的行为就需要去克服这些扰动带来的学习障碍,这样一来神经网络就需要大量的训练样本来收敛,也就是我们所说的高样本复杂度(Sample Complexity)。

现有的方法主要着眼于获得能泛化至所有情况的通用模型,然而随着视觉干扰强度和种类的增加,这种训练方式会大幅度降低RL的训练稳定性和提升样本复杂度。

我们的思考和解决方案:VAI (·)

就拿上面提到的DeepMind Control的Walker举例子,Walker默认背景是天空。在Walker环境里的RL模型很容易被在背景中的变化所影响而摔倒,那为什么一个人走路的时候不会因为背景的变化而摔倒呢?答案其实比较明显:因为人走路的时候注意力集中在路上,会去判断是否有障碍物,而不会去判断天空中星星闪了没有。那我们是否也可以使用类似的方法来增强强化学习的泛化能力呢?我们希望提出一个模型,它并非着眼于Policy或是Feature extractor部分模型的泛化能力,而是使用一个专门的模块对视觉干扰信息进行排除,进而得到一个Distraction-invariant observation space,用这个去除干扰的输入来进行决策,从而不被外界的环境变化所干扰,我们把这个模块记为一个函数,记作VAI (·)。

我们的方法为训练一个单独的模块 (VAI),过滤和任务无关的视觉信息,从而让RL专注于任务有关的视觉信息(在测试以及训练过程中)。因为RL在训练过程中并未接触到干扰信息,因而收敛速度较快,并且对于数据数量的要求也更低。

无监督关键点检测和信息提取

上面我们提到的VAI(·) 模块其实对模型的最终表现很重要,因为不正确的输入(未能完整过滤背景信息或者错误地去掉了前景信息)会导致模型不正确的表现(如果VAI是一个identity function,那么模型就只具有普通模型的抗干扰能力了;如果VAI过滤掉所有视觉信息,那么RL模型获取的信息量为零;这两种情况自然是我们不希望看到的)。我们当然可以使用人工标注+有监督的方式来学习我们的VAI模块,然而人工标注有使用成本,并且并非在所有情况下都合适(例如我们希望模型在机器人上能够自己适应环境和自己训练)。因此我们通过无监督进行信息提取与训练。在这里我们会进行简单的描述,具体的实现的细节大家可以依照原文。

我们的Stage 1: 无监督关键点和特征提取模块

我们的Pipeline分为3个stage。首先是我们的关键点检测部分,如上图所示。在我们的环境里,我们假设前景的变化远远大于背景的变化。这个假设在很多时候是成立的,就比如工厂里机器人的运动的帧与帧之间的变化远远大于背景墙因为日光的颜色的变化。基于这个假设,我们使用了KeyNet Ψ(·) 来提取两帧之间的关键点。对于相邻的两帧,我们把这两帧记为Os和Ot。我们先使用encoder把Os 和 Ot转换为latent space内的表示,记为Φ(Os)和Φ(Ot),接下来使用KeyNet在Φ(Os)和Φ(Ot)分别提取K个模型认为是关键点的点(这里说的点并不是像素点,而是使用Gaussian来定义的有大小的点,记为G(µ; x)):

之后,我们再使用transport函数把Φ(Os)里面这2 * K个点的位置归零,再把Φ(Ot)里面的K个点移入刚才处理后的Φ(Os):

最后再使用decoder来reconstruct Ot,计算Reconstruction Loss:

如果KeyNet训练后可以准确提取两帧里面有差异的关键点,那么通过decoder就可以准确地reconstruct Ot,这也就是说KeyNet找到了我们需要的信息而丢弃了不需要的背景信息;如果KeyNet捕捉到的是不运动的部分,而没有捕捉运动的部分(两帧间的差异),reconstruction就是不准确的。我们通过优化来减小reconstructed t和真实Ot两者的差值。我们训练出Encoder, Decoder和KeyNet分别用于获取feature map、重建图像、以及捕捉前景关键点。

利用提取的信息进行模型的训练

我们获取KeyNet之后其实已经可以直接使用关键点的位置来进行RL训练了,就如上面所举的围棋的例子,state-based RL模型可以直接利用位置信息作为输入,之后进行训练。然而实际我们发现这样效果并不好,我们发现,这些关键点虽然能cover所有的前景信息,但是每一个关键点的位置和某一个特定信息无法建立一致的相关联系。例如下图第二行的绿色和粉色点,在t=0的时候所在的位置和t=80或者160的时候的位置代表着不同的信息。光看单个点其实不能很好地为RL模型提供决策所需的位置信息(例如agent的姿态)。

上图:通过我们的方法提取出来的前景信息,可以看到背景(包括在地面上的影子)已经几乎不可见,对这个部分取阈值以后就得到了我们的Attention Mask; 下图:如果简单地直接使用关键点本身,相同关键点在不同时刻并不对应于相同的位置。因而,无法在各个时刻给出一致性的信息。

因此我们采取了一个相对来说比较巧妙的输入方式:我们把KeyNet的输出按照通道融合起来,这样就消除了每一个关键点在不同的时间代表的语义不同的问题。我们再把融合起来的关键点和encoder的输出做element-wise product,这样就过滤掉了关键点以外的信息,然后我们最后用decoder输出解码后的图像原维度的输出。我们发现这样的输出还是带有背景颜色,即便对应位置decoder输入已经为0。我们发现这是由于decoder里面的bias导致的。然而模型里的bias不能直接设置为0,因为模型里面有很多non-linear的成分。为了解决这个问题,我们使用了Causal Inference来减去Null Input输入的模型输出来得到我们关键点的实际作用,在Causal Inference的框架中,这称为CDE(Controlled Direct Effect),其中At代表我们实际的输入,A0代表Null Input,M代表我们控制住Model bias不变:

最后为了让输出更鲁棒,我们通过一个阈值来区分前景背景,得到一个遮罩层。这也就是下图所示的第2步,这里的输出会作为Ground Truth使用。第二步模型的输出(在取阈值前)如上图第一行所示。遮罩层的公式为:

我们的 Stage 2:使用我们的方法提取出Attention Mask的过程。

使用提取的前景信息来训练VAI (·) 模块

最后就是我们的VAI (·) 模块训练的时候了,我们根据前面的遮罩层来生成Is和It,其中Is含有前景信息,同时我们还通过数据增强或者随机从外部数据集选取图片的方式来模拟了背景的噪声;It则只包含前景信息。Is合成部分的流程如下图所示,我们这里定义三个Augmentation函数,其中是Synchronized Random Crop,分别是背景和前景的Augmentation(见下图示例)。前景的Augmentation如颜色随机调整,背景的Augmentation如随机背景颜色,随机增加彩色方块,随机噪声和随机与之前的原生背景混合等。生成Augmented Image的整体公式为:

其中就是Stage 2的输出的Mask,的定义如下:

我们的Stage 3:我们使用背景Augmentation的方式来学习一个VAI(·) 模块,用它来过滤背景的噪声。我们把VAI(·) 设计为一个轻量级的模块,所以让它输出一个Mask而不是整张图片,这个Mask和原图做element-wise product以后就得到了消除干扰的图片。

如上图所示,在Stage 3里面我们希望模型可以根据Is来reconstruct stage 2生成的binary mask D(Ot),从而区分出前景。我们使用了2个loss来进行训练,一个是用来匹配encoder 的输出的Feature Matching Loss,一个是Is过了decoder后输出的Reconstruction Loss,整个Encoder和Decoder起到了一个过滤背景噪声的作用。经过训练,我们让这个模块能从复杂的前背景信息里面提取有用的前景信息,而防止背景信息输入RL,从而导致RL在训练或测试过程中的不稳定现象。Loss公式:

这里的VAI模块在训练之后就可以用于生成mask,提供给RL模型在训练和测试的时候用于去除背景信息。RL在整个训练和预测的过程中不会见到干扰信息,也就不会有需要拟合所有干扰或者在预测时被干扰的问题。

我们的方法的最终效果

我们的方法在几种不同的干扰上比Baseline更加稳定

我们在Deepmind Control Benchmark上面进行了背景颜色、相似物体和视频背景测试;我们也提出了基于MetaWorld的DrawerWorld Benchmark来测试模型实际在真实纹理上的效果,并在上面进行测试。所有的测试环境以及背景信息均在训练时不可见。效果如下图所示,我们的方法比之前的方法(如PAD)在测试的时候运行速度更快,performance也有大幅度的提升。

下面是在两个测试环境里面的具体结果。相较于之前的state-of-the-art,VAI在DrawerWorld Benchmark上提升了61%~229% 的cumulative reward,在DeepMind Control Benchmark上获得了15%~49%的相对提升。因而VAI具有更强的抗干扰能力。即使在训练的时没有见过测试过程中的干扰信息,VAI依然可以获得较高的cumulative reward。

DeepMind Control (相较于SOTA提升近53%):

其中带有“P”的两组实验指的是使用了外部数据集(Places)进行增强的结果(VAI的使用外部数据集方式就是把它作为步骤3里面的Is,让模型忽略真实情况的背景),没有带P的实验指的是没有使用外部数据集来增强的结果,使用/不使用外部数据集的结果可以分别进行公平比较。

DrawerWorld (相较于SOTA相对提升61%~229%):

其中Grid是在训练环境里测试,Black代表背景为纯黑色,其他的测试环境是不同颜色的纹理作为背景的测试环境。相比于SAC以及之前的state-of-the-art PAD,VAI获得了61%~228%的cumulative reward提升。

团队介绍

几名作者均来自于UC Berkeley和ICSI实验室的Vision Group。两位一作中Xudong (Frank) Wang是UC Berkeley EECS系的PhD学生,Long (Tony) Lian是UC Berkeley CS专业的本科生。Stella X. Yu是ICSI Vision Group的主任。

更多实验结果和方法描述,请参照原文。

论文地址:

CVPR官网链接:

https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Unsupervised_Visual_Attention_and_Invariance_for_Reinforcement_Learning_CVPR_2021_paper.pdf

Arxiv: https://arxiv.org/abs/2104.02921

论文PDF和代码下载

后台回复:VAI即可下载上述论文PDF和代码

后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集

后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集

CVer-强化学习交流群成立

扫码添加CVer助手,可申请加入CVer-强化学习 微信交流群,方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch和TensorFlow等群。

一定要备注:研究方向+地点+学校/公司+昵称(如强化学习+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲长按加小助手微信,进交流群

▲点击上方卡片,关注CVer公众号

整理不易,请点赞和在看

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
【资源说明】 基于强化学习的自动化裁剪CIFAR-10分类任务python源码+项目部署说明(提升模型精度+减少计算量).zip 1、该资源内项目代码都是经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载使用,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能。 目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境。本项目的核心思想是:将模型视为环境,构建附生于模型的 agent ,以辅助模型进一步拟合真实样本。 大多数领域的模型都可以采用这种方式来优化,如cv\多模态等。它至少能够以三种方式工作:1.过滤噪音信息,如删减语音或图像特征;2.进一步丰富表征信息,如高效引用外部信息;3.实现记忆、联想、推理等复杂工作,如构建重要信息的记忆池。 这里推出一款早期完成的裁剪机制transformer版本(后面称为APT),实现了一种更高效的训练模式,能够优化模型指标;此外,可以使用动态图丢弃大量的不必要单元,在指标基本不变的情况下,大幅降低计算量。 该项目希望为大家抛砖引玉。 为什么要做自动剪枝 在具体任务中,往往存在大量毫无价值的信息和过渡性信息,有时不但对任务无益,还会成为噪声。比如:表述会存在冗余/无关片段以及过渡性信息;动物图像识别中,有时候背景无益于辨别动物主体,即使是动物部分图像,也仅有小部分是关键的特征。 以transformer为例,在进行self-attention计算时其复杂度与序列长度平方成正比。长度为10,复杂度为100;长度为9,复杂度为81。 利用强化学习构建agent,能够精准且自动化地动态裁剪已丧失意义部分,甚至能将长序列信息压缩到50-100之内(实验中有从500+的序列长度压缩到个位数的示例),以大幅减少计算量。 ***实验中,发现与裁剪agent联合训练的模型比普通方法训练的模型效果要更好。*** 使用说明 环境 ``` torch numpy tqdm tensorboard ml-collections ``` 下载经过预先​​训练的模型(来自Google官方) 本项目使用的型号:ViT-B_16(您也可以选择其它型号进行测试) ``` imagenet21k pre-train wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz ``` 训练与推理 下载好预训练模型就可以跑了。 ``` 训练 python3 train.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz 推理 python3 infer.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz ``` CIFAR-10和CIFAR-100会自动下载和培训。如果使用其他数据集,您需要自定义data_utils.py。 在裁剪模式的推理过程中,预期您将看到如下格式的输出 关于裁剪器的模型结构设计,本模型中认为如何衡量一个信息单元是否对模型有意义,建立于其自身的信息及它与任务的相关性上。 因此以信息单元本身及它与CLS单元的交互作为agent的输入信息。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值