一文搞懂GPT

GPT的模型

这个不多说,网上一大堆。

GPT的WPE(positional encoding)和WTE(text encoding)

GPT的论文(Improving Language Understanding by Generative Pre-Training)关于postional encoding的内容如下:

We used learned position embeddings instead of the sinusoidal version proposed in the original work.

也就是说,GPT用的是随机位置编码,而不是原transformer论文中用的正弦编码。

摘抄GPT2的github代码如下,可以看出构建模型时,位置编码是直接创建了一个标准差为0.01的随机矩阵。

        if params["precision"] == "bfloat16":
            wpe = tf.get_variable('wpe', [params["n_ctx"], params["n_embd"]], # Position encoding
                             initializer=tf.random_normal_initializer(stddev=0.01, dtype=tf.bfloat16), dtype=tf.bfloat16)
            wte = tf.get_variable('wte', [params["n_vocab"], params["n_embd"]], # Text encoding
                             initializer=tf.random_normal_initializer(stddev=0.02, dtype=tf.bfloat16), dtype=tf.bfloat16)

        else:
            wpe = tf.get_variable('wpe', [params["n_ctx"], params["n_embd"]], # Position encoding
                                initializer=tf.random_normal_initializer(stddev=0.01))
            wte = tf.get_variable('wte', [params["n_vocab"], params["n_embd"]], # Text encoding
                                initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]

        wpe = dropout(wpe, params["embed_dropout"], train)
        wte = dropout(wte, params["embed_dropout"], train)
GPT的数据集和训练方式

GPT是一种大规模语言预训练模型,并且发展到GPT3时,发现了NLP预训练和训练的两个最关键技术,一个是in context learning;一个是chain of thought。

下面针对这两个分别进行分析,以及解析训练时如何体现这两点的。

  1. in context learning

    这是GPT2发现的。GPT2论文的全称是Language Models are Unsupervised Multitask Learners。其核心就是发现了in context learning。也就是从各种普遍文章中预训练,自动学习得到各种专门任务。具体以翻译任务来讲,像GPT1时代包括之前的话,都是给定专门翻译数据集,数据内容是某个语言,其label为要翻译的语言。输入数据内容,用label做groundtruth训练。这种标注成本比较大,而且泛化能力比较弱。GPT2发现,用来预训练的大规模无标注文本(书籍,报纸之类)其实包含了大量的翻译知识。比方说某个小说里面人物的对话包含了一些翻译上的对应(比方说对话模板如下:person1: 翻译一下"apple";person2:“apple"的中文是"苹果”)。那完全可以用这些文章预训练挖掘得到翻译知识,而没必要去标注专门的翻译数据集。除此之外一些QA任务(问答任务)也是可以用这种in context learning。

  2. chain of thought

    这其实和GPT的turing complete证明有关。GPT为transformer的解码模型,解码部分是存在一个停机判断的,这是GPT拥有turing complete的关键。简单说,如果让GPT直接预测结果(结果一般是定长的),而没有中间的推导过程,那么GPT是不具备turing complete的。最简单的例子是让GPT做NP hard的判定问题,输出是“是”或者"否"。输出是定长的,但是算法复杂度是输入规模的NP hard级别,而GPT2的推理复杂度是输入规模的线性级别。如果GPT是turing complete,那不是说NP hard问题都存在线性算法?显然不是。因此需要将问题的解法步骤也做为groundtruth或者说label去训练,让GPT去解码解法步骤,最终给出输出。

下面讲一下in context learning的训练方式

in context learning是一种大规模预训练方式。也即给定输入序列START I1 I2 … In,让GPT去预测输出I1 I2 I3,…,In,END。也即是START token位的输出结果是I1,I1 token为额输出结果I2,依次类推。

GPT巧妙的是这种in context learning完全可以进行token级别的并行!一般人想的是,先输入START I1,去预测I2,然后反向传播训练;然后输入 START I1 I2,去预测I3,然后反向传播训练等等。这种训练方式类似RNN,实在太低效了,训练时间和seq的长度线性相关。但是GPT用了casual mask self attention,I1不会用到I2 I3的信息,因此可以直接输入START I1 I2 … In,然后预测 I1 I2 … IN END。因为训练的时候每个token都没有用到后面的token,因此不会影响解码。

最后讲一下chain of thought的训练方式

很显然chain of thought并不是用于大规模预训练的,他是一种针对特定任务的标注方式。也即是给每个任务标注出推理过程和结果。然后对GPT给定任务输入,去预测推理过程和结果,最终返回结果。因此需要构造一个数据集,然后去微调。

GPT的推理方式

根据之前的推导发现,生成第n+1个token时,完全不需要对前n-1个token再次运行推理,因为前n-1个token完全没有用到第n个token的信息。因此只需要对第n个token进行推理,与前n-1个token的中间结果进行cross attention。因此推理速度非常快,但是需要保存所有的中间激活结果。查看了hugging face的GPT2实现,是否使用历史token的中间结果,由变量use_cache决定。也即是实现中确实利用了这个来进行推理加速。训练时则设置use_cache为false。

其实从这个角度来看GPT其实就是RNN,只不过GPT的状态量是之前的所有计算结果,而RNN会把之前的所有状态量压缩到一个固定size的状态量,且因为这个压缩过程,导致RNN无法并行训练。

理论上来说,GPT对casual mask限制得太死了。在生成第n+1个token时,第0个token明显可以看到且用到第1到n个token的信息。但是GPT限制死了即便这种情况第0个token也不能用到第1到n个token的信息。好处在于,训练的时候,n个token可以完全并行训练;推理的时候,生成第n+1个token时,前n个token没必要再次计算。

GPT的复杂度

最近看了不少efficient transformer的论文。也自己尝试推导了一下一些线性的attention公式。虽然推导到最后发现,得到的公式是一种低秩方法。而且从自己的推导过程中发现,所有的低秩方法很可能都是将QUERY或者某次QUERY到的所有KEY固定到一个常数个数。这是个人猜测,但很可能有个证明方法证明确实是如此。

总而言之在这个过程中,发现很有可能所有降低transformer复杂度的尝试都是徒劳。这涉及到计算复杂度,而且个人认为将来学术界用可计算理论,计算复杂度理论和形式语言理论研究transformer和GPT必定成为潮流。下面以GPT为例说明为什么这样的操作是徒劳的。注意下面的都是个人思考,不保证正确。

很有可能证明如下的定理:

  1. 不存在停机判断的GPT网络计算复杂度必定是 O ( n 2 ) O(n^2) O(n2),或者GPT解码步长为 O ( n 2 ) O(n^2) O(n2)

  2. 存在停机判断的GPT网络计算最高复杂度必定是 O ( n 2 ) O(n^2) O(n2),或者GPT解码步长为 O ( n 2 ) O(n^2) O(n2)

这个停机判断并非是GPT解码的停机判断,单纯是针对输入编码的GPT网络停机判断。

一种可以初步证明这个推断的问题是:

给定输入 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn,判断 所有 ∣ x j − x i ∣ |x_j - x_i| xjxi是否小于某一常数 α \alpha α,如果是,输出1,否则输出0。

很显然,如果GPT不给出额外的推断token,那么网络的复杂度必为 O ( n 2 ) O(n^2) O(n2)

如果GPT可以给出额外的推断token,那么网络的复杂度可以小于 O ( n 2 ) O(n^2) O(n2),但此时推断步长复杂度必定为 O ( n 2 ) O(n^2) O(n2)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值