Gato:A Generalist Agent

发表时间:11 Nov 2022 论文链接:https://readpaper.com/pdf-annotate/note?pdfId=4689785740490440705&noteId=2412686834489472512 作者单位:DeepMind

前言:近日,DeepMind发布了文章《A Generalist Agent》(简称Gato),号称是“可以玩600多种任务的通才AI”。Gato能聊天,能给图片做标注,能玩Atari游戏,还能控制真实的机械臂,可以说是把AI主流任务做了个遍。文章一发出来后,不出意外热度很高,各大媒体争相报道,但具体评价却是褒贬不一。赞美的声音说它开创了先河,在一定程度上证明了通用agent的可行性;质疑的声音则表示Gato距离 AGI 还很远,并没有宣传的那么惊艳;那么Gato到底怎么样呢?接下来我会从个人视角详细介绍下这篇研究。

Motivation

首先我们考虑一个问题:实现"通才Agent"的最大难点是什么?

AI领域内常见的方向有:

  • NLP:一般输入和输出都是文本字符。例如机器翻译;

  • 图像领域:一般输入和输出都是图像。例如图像生成;

  • 强化学习:输入可以是图像、矩阵向量等,输出是离散或连续的数值。例如玩游戏、下棋;

(当然也有一些交叉任务,例如图片标注是给定图片,输出相应的文本标注。就是一个很好的将Text和Images联系在一起的研究,感兴趣可以看一下,这里就不展开了)

如今的Agent,在设计时必须规定好输入和输出的形式,后续的所有应用也必须在这个规定的框架内,AI才能正常工作。通过上面的介绍可以看到,不同任务的输入和输出差别很大。如果想让一个Agent可以处理所有的这些任务,就必须将各个任务的输入和输出形式统一起来,才有可能使用一套框架解决这些问题。而如何统一,则是走向“通才Agent”首先要解决的问题。

解决方法

面对包括NLP、图像、RL等领域的不同任务,Gato会将各种形式的输入(图片、文本、连续值向量等)按照一定规则,都转换成一种统一的表示形式,以此将不同任务转移到相同的状态表征空间。在这个表征空间下,Gato将所有任务都看作是序列生成,以此进行统一的自回归监督训练。下面我们来看一些重要的细节。

  1. 输入输出表示形式统一化

我们知道,现实任务的输入可以是文本、图像、数组等多模态的信息。一种统一的想法,是将这些多模态的信息通通转化成某一种输入格式。例如,将文本、图像等都转化成底层的字节(raw byte)。而Gato选择的方法是:将所有信息都转化成token(tokenization)。

这里token是源自NLP中的含义,代表最底层的处理元素,在NLP中通常是单词。例如,针对“Where are you from”这句文本,where\are\you\from是4个token。

至于为什么Gato要选择tokenization,是因为作者想用transformer作为网络结构,而transformer的输入输出是建立在token上的。简而言之,transformer一般会将由token组成的序列输入进网络,然后做很多运算,最终输出一个预测的token序列,并让预测的token序列和真实的token序列不断接近,即监督学习。这里有两点需要注意:1.输入是token序列;2. 需要有真值做监督学习;

具体而言,针对不同格式的输入,分别有不同的处理方式来生成token:

  • Text:文本,直接通过SentencePiece([4])(一种分词工具),将文本切分成词,再将词转换成词表中的index。词表总大小为32000。

  • Image:图像,前文我们提到了,输入必须是序列才行。文本天然是由一个个词组成的序列,而图片并不具备这种优势。为了解决这个问题,Gato使用了ViT([5])中的方法。简单来说,如图1所示,一个图片进来后,首先画格子做切分,每个格子大小为16x16,每个格子就是一个patch。再将这些patches按照顺序排列成一行,就得到了图片对应的token序列。

图1. 图片切分成patches

  • Discrete value:离散值,像Atari中的动作按钮之类的。这里直接使用原始的整型数作为token就行,数值范围映射到 [0, 1024)

  • Continuous value:连续值,速度、力矩之类的状态。首先将 [-1, 1] 离散成1024个区域,然后将浮点数通过mu-law变换映射到 [-1, 1]范围内,并在1024个区域内找到对应落到的区间,即完成浮点数到整型数的转换。整数则可以直接作为token使用。为了不和前文text的32000个token重叠,这里的1024的区域会平行偏移到 [32000, 33024] 范围内。

至此,通过上面的变换,大部分形式的输入(图片、文本、离散\连续值)都可以转变成统一的token序列,在完成多模态信息表示形式统一化的同时,也满足了使用transformer的一个条件。

  1. Embedding

拿到输入token序列后,还需要经过一个embedding处理,来将token转换成特征向量

  • text、离散/连续输入经过tokenization后,token为整型数(即index)。拿整型token去embedding table(可学习)里查找即可找到对应的embedding向量,是一个标准的词嵌入过程。

详细过程如图2所示,①是将一个4维的float状态转换成Nx4的特征矩阵,N代表embedding的特征向量长度,每个float状态都会转换成一个N维的向量。

Token的位置信息经过一定编码(Local Pos Encoding)后会放到特征序列中,算是transformer的标准操作,本文就不详细展开了。

图2. 连续\离散输入处理过程

  • image的话,由于token是二维float矩阵,所以直接用ResNet网络(可学习),对每一个patch token处理后,生成对应的特征向量。

图4展示了Gato的训练过程,Text、Image等信息会转换成token序列并组成Batched input,然后进入Gato网络得到输出,最后用输出值和真值算loss。整个过程还是很简洁的。

实现方式

5. 其他细节

我们来看看Gato中比较值得说的一些细节。

  • 监督样本

前文我们说过,Gato使用transformer有两点需要注意:1. 输入是token序列;2. 需要有真值做监督学习。第1点使用tokenization可以实现,第2点则需要预先有标记好的真值样本。参考图3的loss公式,Gato训练的本质就是让网络预测值不断接近真值,进而提升能力。那么Gato训练时的真值从哪来呢?

在NLP和图像领域,通常有人工预先标注好的数据集,数据集里面就有真值,是人类智慧的结晶,照着学就行了。但是在RL领域,一般是没有这种数据集的。RL通常是要不断和环境交互,在环境中学习,而不是和“数据集交互”,所以RL通常没有标注好的数据集。

为了解决RL没有数据集的问题,Gato采用的方式是:使用很多预先训练好的SOTA Agents生成(s,a,r)数据。也就是说,先训练了很多agents,每个agent在某个任务上很强大。然后用这些厉害的agents在环境中交互,并记录下相应数据,也就得到了“数据集”。后续Gato只要在这个数据集上进行监督学习就可以了。所以从这个角度来说,Gato其实还没有走到RL,而是在模仿厉害的teacher agents。

下图是各个任务和数据集名称,sample weight是一个batch中各个任务样本的比重。

  • 模型结构

Gato采用了标准的transformer结构,没有在这方面做过多工作。整个模型有1.2B参数,24层,embedding向量大小2048,feedforward hidden size 8192。可以看到,这个量级虽然不是很大,但也不是一般人能跑的起来的。至于为什么DeepMind不用更大的模型呢?是因为模型越大前向耗时也就越大,在NLP、图像这种对前向耗时不敏感的领域倒没什么问题,但是对于RL中的控制任务,环境并不会卡住来等待模型返回动作,如果前向时间过长,可能下一个状态都产生了,模型还没有返回上一个状态的动作。所以1.2B应该是在允许延迟条件下的最大模型了。

训练放在了一个16x16 TPU v3上,batch size 512,sequence length 1024,训练4 days。

  • Prompt

Gato是在600多个任务上进行训练的,这么多任务,难免一些任务之间会比较相近,比如有相似的observation、action。这就让Agent很容易产生混淆,导致agent不知道当前要解决什么任务。一个简单的做法是给每个任务加上唯一的one-hot特征,这样就很容易可以完全区分开不同的任务。但是缺点也很明显,one-hot特征无法在新任务上迁移,而且后续扩大任务规模时,one-hot也需要相应扩大,很难重复利用。

为了解决任务难以区分的问题,Gato使用了prompt。Prompt最先出现在NLP领域,属于比较新的研究了。简单来说,就是预先写好很多模板prompt,然后把输入信息按照模板的方式重写一遍,这样输入空间就大大规范化了,不同任务间的差别也更加明显,且具备一定的迁移性,适应zero-shot。Prompt写的越多越好,效果也会越好。

对于NLP文本,prompt比较好生成。但是RL中的控制任务,就比较麻烦了。对于RL,50%的prompt是到达终止状态的轨迹,50%的prompt是在episode中均匀采样。在训练时,一个batch中25%的轨迹会加上对应任务的prompt。这个prompt就像一个condition一样补充进输入,让agent方便知道当前任务的信息。

实验:Gato一共有600多个任务放在一起训练,其中大部分都是RL的控制类任务。图像文本类只有对话生成和图片标注两个任务。值得注意的是,600个左右的控制任务,是分成了几个大类的,如Atari game、Procgen、DM Lab、robot等。Atari和robot任务之间的差别确实很大,但Atari和Procgen,以及Atari内部的几十个游戏之间的差别不是很大,所以Gato虽然一共使用了600多个任务,但实际解决的问题数目还是要打个折扣的。

Robotics-RGB Stacking:

输入RGB图像,目标是控制机械臂,将物体摆出相应的形状(如图11上)。在这个实验中,Gato额外和BC-IMP进行了对比,其中BC-IMP只在该单一任务上进行了训练。从图11的实验结果中可以看到,在所有任务上训练的Gato,与只在该任务上训练的BC-IMP,在面对测试集任务时取得了相近的效果。说明Gato泛化性不错。

图11. RGB Stacking

Simulated control tasks:

下图是simulated control tasks,也就是RL环境的实验结果。横轴代表“能力”,0是random agent,100是最强的expert agent;纵轴代表Gato在多少个任务上达到了横轴对应的能力;例如,在60处画一条竖线,代表达到60% expert水平的任务有450个左右。可以看到,随着横轴的提升,Gato的纵轴数目也在下降,能完全达到expert水平的任务只有200个左右。从这一点上来看,Gato够通用,但能力还有提升空间。

图9. 控制类环境实验结果

Image caption and Chitchat:

下图分别是图片标注和聊天的结果展示。其中图片标注的结果还可以,聊天的话就不是很理想了。而且这类具备数据集的任务,是可以打榜和其他工作对比的。但是Gato并没有这么做,也没有在文中介绍很多这两个任务的实验结果,个人猜测可能是打榜分数并不理想。

结论:文介绍了Gato的工作原理、实现细节、和实验结果。Gato将NLP,Image,和RL领域在一定程度上进行了整合,将多模态输入统一为token序列,并把各种类型任务转化成统一的序列生成,使用一个transformer模型完成了600多个不同的任务,Gato的实验结果也证明了该统一思路的可行性。

消融实验结论:

  • 模型越大效果越好(具体参数量见图7)

  • 模型越大泛化性越好;Fine-tuning可以提升效果。

  • 单任务训练结果比多任务效果好(有点反直觉)。只在Atari上训练的specialist Gato,要比在全部任务上训练的generalist Gato效果好。说明增大了任务规模后效果确实会降低,当然这个问题有可能通过增大模型来解决,毕竟在NLP里有过成功案例了。

个人觉得Gato的优势在于:

  1. 第一个尝试将NLP、Image、RL等多个领域整合在一起的工作,具有一定的开创意义;

  2. 使用一个模型框架,一套参数,在600多个不同领域的任务上都取得了还不错的成果,以往工作从来没有一个模型可以完成如此多的任务;

劣势:

  1. Gato虽然在RL环境中进行了实验,也取得了不错的效果,但本质上是监督学习,训练时也没有和环境进行交互,并没有完全把RL整合进框架中,因此Gato在RL领域还不能直接应用;

  2. 模型比较大,且用了很多数据集,一般人还跑不起。这也是被其他人诟病的一点,担心“一切只关乎规模”;

  3. 虽然在600多个任务上进行了实验,但是大部分任务时RL领域,且存在一定的重复性,并不是完全独立的600多个任务。NLP和Image的任务很少(只看到2个),偏科很严重,“通才性”有一些不足

  4. 实验结果令人惊艳的点在于任务数量巨大,而不是效果超凡。当然可能补充更多的数据,训练更多的时间会有提升吧。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ming_Chens

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值