广告行业中那些趣事系列67:基于知识蒸馏的在线分类模型

导读:本文是“数据拾光者”专栏的第六十六篇文章,这个系列将介绍在广告行业中自然语言处理和推荐系统实践。本篇主要介绍基于知识蒸馏的在线分类模型,对于希望将蒸馏技术应用到实际工作的小伙伴可能有帮助。

欢迎转载,转载请注明出处以及链接,更多关于自然语言处理、推荐系统优质内容请关注如下频道。
知乎专栏:数据拾光者
公众号:数据拾光者

下面主要按照如下思维导图进行学习分享:

7b6b601025c1b51fb63b01d734d1c6e0.jpeg

什么是知识、什么是蒸馏?首先,蒸馏本质是迁移学习,即一个简单模型去学习复杂模型的知识,这里的知识可以多种多样,在此举个最简单的例子——知识就是神经网络输出的logits,蒸馏就是指将复杂模型的知识迁移到简单模型,可能大家会有疑问:迁移学习就迁移学习,为啥要叫蒸馏?先看下面公式(1),z就是前面提到的logits,T是温度,q为概率分布。也就是说,简单模型如果想学习logits,需要先经过一个“有温度的”softmax得到概率分布——soft label,然后让简单模型去学习这个soft label,现在跟温度扯上关系了,所以把该迁移过程叫蒸馏。至于为什么要用有温度的softmax,这是因为logits本身是一种暗知识——信息熵很高,简单模型不易学习,所以softmax相当于一种暗知识提取方法,让知识变得简单,至于温度,就是在控制信息熵的大小(温度趋近正无穷时相当于直接学习logits)。

e17af2b905ff4b3068463d39a6fcae4b.jpeg

关于知识蒸馏的基础知识,推荐广告行业中那些趣事系列50:一文看懂BERT知识蒸馏发展进程,如果想对蒸馏有个更全面的了解,推荐综述文献《Knowledge Distillation: A Survey》和Hinton开山之作《Distilling the Knowledge in a Neural Network》,包括一些工程价值较高(对作者启发较大)的论文《Big Self-Supervised Models are Strong Semi-Supervised Learners》

1 项目背景

1.1 输入输出

1.1.1 业务

如下图所示,用户在OPPO浏览器搜索框输入关键词时,在搜索框下方会实时跳出APP的广告,从输入输出角度来看,其本质就是“query to ad”任务。

945dc35f0ce819eb31c09e3e2d8f4dfe.jpeg

1.1.2 在线分类器

更进一步,如下图所示,任务范围缩小到“query to category”(后文称category为类目)

6d19a49f9153fb72f7efafba231ea35b.jpeg

1.1.3 蒸馏框架

从下图可以看到,在线分类器就是我们蒸馏框架的输出,输入为query和其对应的teacher logits

7f65f6c6116acc112fc40c851041de94.jpeg

1.2 为何要蒸馏

也可以换句话问:“为什么不直接使用teacher模型分类”,简言之就是teacher模型太大了,导致推理比较慢,超过线上时延上限。所以我们要同时开展“底层框架优化”和“蒸馏”,前者目的是让teacher直接上线,但过程可能比较曲折缓慢,相比之下“蒸馏”实现简单、开发周期较短,秉着“利益最大化”的原则(前提是保证效果),我们选择前期重点攻关“蒸馏”。当然,在攻关的真空期,teacher也不会闲着,仍然要发挥他最大的作用——行业词包召回,如下图所示,我们使用了历史query来训练teacher分类器,并在第T天预测T-1分区的query,然后将这些带标签的query更新到行业词包,所以这是一个“T+1离线任务”,当然这样做的重要依据是:部分类目的query热门词带来的消耗是比较高的,也就是说今天有人搜索了“股票”,明天大概率也有人搜索“股票”,所以行业词包的方式是可以覆盖这部分词。然而,像“万华化学股价有点高,现在可以入场吗”这种长尾query,行业词包就很难实时覆盖到了,所以我们的蒸馏模型目标是:召回teacher无法实时召回的长尾词

e1e6651540155eca0717e42bb881061b.jpeg

2 算法选型

在算法选型上,作者经历了tinybert -> textcnn -> tinybert,看上去像原地绕了一圈,其实不然,每步都是有理可依的,都有其价值所在,接下来本文会重点介绍我们选型的依据。

2.1 teacher结构

网上大量论文和文章中所指的teacher一般都是比较常见的结构,比如bert二/多分类,假如我们现在的有一个异形结构的teacher,论文和github源码能直接套用吗?所以,蒸馏前首先要对teacher模型有个深入的理解,下图是我们teacher分类器的结构图,可以看到相比于业界常见方案,我们采用了“公共encoder+独立prediction layer”的结构来实现“multi-label classifier”,至于这样做的原因,是综合考虑了项目环境、性能、模型迭代策略等因素,本系列其他文章有详细介绍,这里不再展开。

5d82164739632bdccf7f4a2d80ba051c.jpeg

2.2 算法调研

2.2.1 FastBert

我们对比了多种蒸馏算法,从论文实验效果为依据,挑选了三种候选方案:fastbert、tinybert和bert2textcnn。其中又首先放弃了fastbert,其原因如下图所示:fastbert会在每层encoder后蒸馏出一个分类器用于筛选样本,论文中每层使用的分类器属于“multi-class”,而我们teacher属于“multi-label”,所以每层student的设计是个难题,而且对于“multi label”而言,并不能仅通过一次判断就能筛选样本,所以筛选样本的策略也需要自己设计,就仅从实现难度来看,fastbert是不合适的。

66e100a80633bcd0afb256883dd3f0c6.jpeg

2.2.2 TinyBert

有现成pytorch开源代码,而且有几个版本的蒸馏模型,我们在官方蒸馏模型上用自己训练集微调了一下,发现效果很差,由于项目时间节点紧迫,所以效果差的原因并没有深究(其实回过头来看原因还是挺明显的:没有用query进行transformer蒸馏+训练语料太少,后面会详细讲到,此处略过)。既然不能直接用,那么接下来先看看简单的textcnn

7a156622bbf5f461fa729193a3350e11.jpeg

2.2.3 TextCNN

可以看到textcnn结构非常简单,可以立即实现,而且也已有人验证了bert2textcnn的有效性,再则就是前期会做大量探索性的实验(比如超参、语料、性能),一个小而简单的模型会显著减少实验代价,综上可以确定textcnn是最适合现阶段的方案。

ff46b3df357145c9dd65df42bd524ff4.jpeg

3 模型迭代优化

3.1 二元分类器(textcnn)

c031785adde7c4a0fa23f296742e7ade.jpeg

初版的主要目的是快速验证bert2textcnn蒸馏的有效性,所以暂先实现“二元分类器蒸馏”,等验证有效,再推广到“multi-label蒸馏”。上图展示的是二元分类器的蒸馏框架,与网上“蒸馏”资料描述的一样,并无特别之处,但作者从中获得了一些重要经验,为今后更复杂的实验打好了基础,这里的“重要经验”指的是:

  • 确定了蒸馏代码的正确性

  • 确定了蒸馏损失函数为交叉熵,理论上用KL散度是一样的,但作者观察其loss时,发现会有负数出现

  • 找到一套合适的超参

  • 确定全部使用无标注数据蒸馏——distillation loss+student loss的效果并不明显,而且语料集会受label限制,导致量级不够,实验表明,完全使用无标注数据,效果有显著提升,甚至可以超过teacher。这里个人理解是,当teacher有充足的训练样本时,student基本不会超过teacher,但若teacher训练样本不充分,那么student可以从大量的无标注样本的知识中得到增益。可能读者会有疑问:Hinton在其蒸馏开山之作中提到——若将真实label及其loss带入蒸馏,其效果有明显提升,为什么我这里就放弃了?其实作者也一直抱有这个疑问,直到看到Hinton的最新一篇论文后,才打消了疑虑:

Distillation typically involves both a distillation loss that encourages the student to match a teacher and an ordinary supervised cross-entropy loss on the labels (Eq. 3). In Table 2, we demonstrate the importance of using unlabeled examples when training with the distillation loss. Furthermore, using the distillation loss alone (Eq. 2) works almost as well as balancing distillation and label losses (Eq. 3) when the labeled fraction is small.
Big Self-Supervised Models are Strong Semi-Supervised Learners

可以看到,当你的带label数据在总的样本的占比很小,那么全部使用无标注数据蒸馏和加入带label的,是区别不大的。不过从工程实现角度来讲,两种方式代码开发量相差不大,无妨把label也加上,只不过在加权求和时,注意除以温度平方,如下:

T2 = temperature * temperature
total_loss = student_loss_w * student_loss + distill_loss_w * distill_loss / T2

3.2 multi-label分类器(textcnn)

由binary蒸馏框架过渡到multi-label是很自然的,如下图所示,我们可以将学生设计成跟老师一样的结构——公用encoder+各自的prediction layer,每个student各自学习自己teacher的知识,然后将各类loss加起来。不过,维护每个类目的prediction layer是很麻烦的(类目越多越麻烦),所以还需要从工程的角度对蒸馏框架进行一些优化:1)将所有teacher的softmax拼接成一个向量,注意这里是多个2分类的softmax,有别于多分类的softmax;2)同样的构建student softmax;3)求loss时,先对softmax向量reshape,然后矩阵计算各类的交叉熵,最后reduce_sum得到最终loss。优化后,不仅只用维护一套参数,而且实现了各类并行蒸馏。实际上我们还可以将此框架继续抽象:隐去具体的teacher结构,蒸馏框架只保证student与teacher知识对齐,而不管这些知识怎么来的,所以这里的softmax不再局限于是1x2,也可以是1xM(可混合multi-class),而且也不局限于共用同一encoder。确定框架后,我们将会面临multi-label带来的新问题:

  • 用什么策略来评估multi-label模型

  • 不仅存在类内数据不平衡,类间也存在不平衡,student共享参数部分会不会受多数类的影响大,而少数类的影响小?要怎样获取语料,来缓解这些潜在影响?

fd4108207977e8d2d9d843206288a433.jpeg

评估:首先,我们的任务是分类,所以需要先确定每个类的阈值,然后根据该阈值算出precision、recall、f1,那么如何确定每个类的阈值呢?我们的思路是,从实际需求出发,既然模型是服务于在线分类,那么我们直接用贴近业务的指标来确定阈值——在precision=0.8条件下(控制badcase率),使recall最大(消耗最大化)的阈值, 然后对各阈值对应的recall或f1求加权平均,其权重可理解为从业务角度来看某类的重要性。

数据不平衡:虽然我们完全采用了无标注数据,但事实上,他们在类目维度上也是不平衡的,只是我们没法直观去感受这种不平衡。针对这个问题,作者借鉴自训练的思想,让teacher模型帮忙筛选一批语料,具体步骤为:1)让teacher去预测一批语料,得到其伪标签和概率值;2)按伪标签分层采样——伪标签分组并倒排概率值,然后每组内取topK(如果不信任某类目的teacher,可在该组内随机采样),需要注意的是topK不宜取得太多,因为类目A/B/C的正样本对于类目D来说是负样本,如果语料真实分布中的D正样本本来就很少,甚至取不够K个,那么其正负样本比是很不平衡的;3)随机采一些负样本(未被打上标签的或“三俗”类标签)。正如前面所说,这种方法只是暂时缓解但未能消除数据不平衡,其原因是:首先此方法依赖teacher的准确度,所以最终取的每个类的正样本与真实正样本是存在误差的,再退一步说,假设teacher准确度100%,而且每个类都是平衡的,那么当我们的类目达到100时,对于每个类目来说,其正负样本比为1:99,所以随着类目的增加,最终还是会不可避免地遇到不平衡问题,但至少从目前的实验来看(类目数为44),这种方法仍然适用。

3.3 细节优化

大框架确定下来后,接下来就是一些细节的优化,虽然预期这些细节优化的增益不会很大,但在暂时没有找到其他大的优化点时,是可以花点时间来做尝试的,毕竟蚊子再小也是肉。

3.3.1 Pretrained Word Embedding

textcnn的词嵌入层参数是随机初始化的,有论文表示,若使用Pretrained Word Embeddings来初始化词嵌入层,模型会有提升,所以作者尝试了如下w2v:中文wiki/人民日报/搜狗新闻/知乎问答/微博/Mixed-large/google-w2v/bert-embedding,以及他们的融合,实验结果为在数据集A上有较明显的提升,但在更丰富的数据集B上效果就不明显了,甚至还会下降,所以在训练语料匮乏的情况下可以尝试下Pretrained Word Embeddings,但如果语料充足的话,可以不考虑此方法。

3.3.2 温度

At lower temperatures, distillation pays much less attention to matching logits that are much more negative than the average. This is potentially advantageous because these logits are almost completely unconstrained by the cost function used for training the cumbersome model so they could be very noisy. On the other hand, the very negative logits may convey useful information about the knowledge acquired by the cumbersome model. Which of these effects dominates is an empirical question. We show that when the distilled model is much too small to capture all of the knowledege in the cumbersome model, intermediate temperatures work best which strongly suggests that ignoring the large negative logits can be helpful
Distilling the Knowledge in a Neural Network

按照论文的原话,温度越高teacher输出的概率值越趋近于均匀分布,蕴含的知识越丰富,但若学生模型很小的话,是无法掌握这些知识的,所以建议使用一个温和的温度。再结合我们实验来看,textcnn温度在1~1.5的时候效果最好,所以做蒸馏实验时,一般保持默认温度T=1.0就可以了,前期花时间去优化温度意义不大。

3.4 TinyBert

3.4.1 重拾TinyBert

目前为止textcnn取得的效果还是比较可观的,甚至在一些重要的标签上超过了当前teacher,不过还是存在一些明显的问题:1)student始终受限于teacher(虽然可以超过,但涨幅也有限);2)分析一些case发现,student还是无法理解一些语义。针对问题1)我们只能换个“更好的老师”,所以我们也一直在迭代优化teacher,而且如预期所料,基于新teacher的student也会随之提升;针对问题2)我们的蒸馏语料已达1000w,基本足够了,所以我们有理由质疑——问题关键在模型身上——随着类目的增多,要求掌握的知识也就越多,textcnn太简单以至于无法深入的理解某些query。所以,面对这个问题,目光很自然地就投向了tinybert——直接用少层的bert替换textcnn,虽然前面调研阶段也尝试并放弃过tinybert,但实验证明了textcnn都能表现得不错,若其他环节不变,只是换一个更复杂且结构与teacher更相似的student,理论上是应该有明显提升的,所以可以得出这样个结论——之前的tinybert用得太浅,大概率是忽略了一些关键点,所以这次,我们得慢下来深入研究tinybert了!

其实到现阶段,对tinybert的研究已经变得相对简单些了,因为项目初期整个蒸馏框架还未搭建,在实现tinybert时,要定位问题还比较难(这也是当时浅尝辄止的原因,基础建设和知识储备都不够,短时间无法上线),但现在我们蒸馏框架已比较稳定,所以可以将所有精力都投入到模型中,最重要的是,随着知识储备量的增加,我们现在有理由相信——tinybert要比textcnn好,若不是,多半是自身其他环节出了问题。

3.4.2 决定造轮

虽然tinybert有开源代码,但作者最终决定自己造轮,原因如下:

  • 开源蒸馏框架是基于pytorch开发的,而我们目前训练、部署是用的tensorflow,虽然可以用一些模型转换工具来实现跨平台训练、部署,但是毕竟环节越多越容易出问题

  • 我们的蒸馏实验会做各种尝试,难免会改动开源代码,开源代码一般考虑比较周全,所以在我们的需求面前,其代码会显得冗余,维护起来也比较麻烦

  • 最重要的是,tinybert核心代码实现并不复杂,而且我们已经有了符合自身需求的蒸馏框架,从textcnn过渡到tinybert,可以看作多写一个TinyBert类,然后替换掉TextCNN,所以只要把论文几个核心点理解了,实现出来也是水到渠成的事。

3.4.3 实现

eb17518aecddd89c954e7b56de19fb1c.jpeg

如上图所示,tinybert分为两大步蒸馏(我们使用的蒸馏语料都采自线上真实无标注query,量是够的,所以可以略过数据增强):

  • General Learning/Distillation(GD):其实第一步的GD是可以略过的,因为论文作者的消融实验表明,GD带来的增益较小,虽然在实现中GD只是TD的一个环节,代码甚至可以复用,但由于GD要使用更大规模的语料,蒸馏时间成本会显著增加,权衡之下还是弃用了

  • Task-specific Leraining/Distillation(TD):tinybert的核心部分,其transformer蒸馏又涉及到三种Layer:1)Embedding layer;2)Transformer Layer;3)Prediction Layer。需要注意的是,下图用分段函数来表示各层的损失函数,这就意味着,我们不能把所有类型的loss都融合到一起,然后蒸馏一次就完事了,而是要先蒸馏Embedding layer,然后再蒸馏Transformer Layer,最后蒸馏Prediction Layer,一共三步!所以要完成一次TD,时间成本也是比较大的, 好在论文证明了Embedding layer蒸馏的影响较小,所以我们可以省略这一步(实际上官方开源代码也是省略这步的),最终我们要实现的蒸馏步骤仅有:先transformer layer蒸馏,再prediction layer蒸馏。

285b5c02366163f92ac9aa772cab9a74.jpeg 2cddd249fe84776d89188a0e276e531d.jpeg

Transformer Layer蒸馏

这部分没有温度控制,实现也比较简单,关键是要按照论文提出的方式提取知识:

  • 确定学生和老师层级的对应关系

  • attention layer知识为softmax前的attention score

  • transformer输出时,若tinybert的隐层维度与teacher不一致(如teacher为768,student为312),那么需要乘以变量Wh来实现对齐,以下代码中他们维度都是768,所以没有Wh

def create_transformer_distillation_graph(input_ids, tinybert_conf_path, pre_model_dir):
    """    
        :param input_ids: token ids    
        :param tinybert_conf_path: tinybert配置(json)
        :param pre_model_dir: teacher bert 预训练模型路径
        :return: 蒸馏loss    """
    s2t_layer_map = {0: 1, 1: 3, 2: 5, 3: 7, 4: 9, 5: 11}  # 学生与老师隐层映射关系,暂写死    
    logger.info("get teacher's feature based knowledge...")
    t_output_att, t_output_trans = get_feature_based_knowledge(input_ids,
                                                      list(s2t_layer_map.values()), 
                                                      pre_model_dir)
    logger.info("build tiny bert model...")
    bert_config = BertConfig.from_json_file(tinybert_conf_path)
    tinybert = BertModel(config=bert_config,
                         is_training=True,
                         input_ids=input_ids,
                         input_mask=None,                         
                         token_type_ids=None,                         
                         use_one_hot_embeddings=False,
                         scope="tinybert")
    graph = tf.get_default_graph()
    s_output_att = [
        graph \
        .get_operation_by_name(f"tinybert/encoder/layer_{i}/attention/self/add").outputs[0]
        for i in s2t_layer_map]
    s_output_trans = tinybert.all_encoder_layers

    logger.info("build distillation loss...")
    with tf.name_scope("distill_loss"):
        # attention loss
        loss_att_each_layer = [
            tf.compat.v1.losses.mean_squared_error(s_output_att[i], t_output_att[i]) 
            for i in s2t_layer_map]
        loss_att = tf.reduce_sum(loss_att_each_layer)

        # transformer loss,如果学生隐层size跟老师不一致,则需要乘Wh来对齐        
        loss_trans_each_layer = [
            tf.compat.v1.losses.mean_squared_error(s_output_trans[i], t_output_trans[i])
            for i in s2t_layer_map]
        loss_trans = tf.reduce_sum(loss_trans_each_layer)
        loss_distill = loss_att + loss_trans

    return loss_distill

Prediction Layer蒸馏

prediction layer蒸馏跟之前的方式一样,需要注意的是,蒸馏前将上一步transformer蒸馏过的参数加载进来,并且不要冻结参数!

3.4.4 效果

  • 测试集上大部分类目效果超过teacher

  • 回归测试,其结果与测试集基本一致,大部分都是正向优化。

  • 线上指标

下图绿色曲线是“词包(前面介绍的离线teacher)的下载量”,蓝色是“student下载量”。可以看到tinybert上线前,student与teacher的差距还是很明显的,当tinybert上线后,效果立竿见影——差距明显缩小,最终证明了tinybert确实比textcnn好。

185ff7198a0a60ba164c987bb92d76aa.jpeg

4 总结与展望

回望走过的路,略显曲折,不禁会问:为什么我们的思考路径是——放弃tinybert,再重拾tinybert,而不是一开始就咬定tinybert?其实上文已经给出了答案,我们在制定方案时,时刻谨记我们思考路径的底层逻辑是什么——项目需求,项目初期从0到1,要突出一个“快”字,所谓聊胜于无,这时候若花很多时间憋个大招,项目可能就延期了,评估了开发tinybert可能比较慢(至少比textcnn慢),那么放弃他也是合理的;等到现阶段稳定下来后,需求是要攻坚,这个时候可以适当慢下来做些更深入的研究,tinybert自然也就再次被提起。

从目前的指标来看,基于知识蒸馏的在线分类模型可能快到瓶颈了——毕竟他的上限是teacher,所以我们也一直在思考其他场景下的蒸馏:

  • 自蒸馏:各论文中所述的自蒸馏可分为两种:一种是属于self-training范畴的、另一种属于mutual learning范畴,前者跟自训练差不多,唯一区别就是自训练一般使用伪标签训练,而自蒸馏使用知识(soft label)训练;后者指的是找一个跟teacher结构一样的student,然后一起学习(一起更新参数)。作者只尝试过前一种自蒸馏,新的teacher效果有明显的提升,所以teacher在迭代优化的时候,除了自训练、主动学习,自蒸馏是否也可以加入其中

  • 在线蒸馏:在线蒸馏属于mutual learning范畴,论文中提到效果上——teacher和student都会提升,性能上——一次蒸馏就完成了teacher和student的优化,当前蒸馏有个痛点就是训练时间太久,期望在线蒸馏降低时间成本

  • 将带label的训练集利用起来:目前我们仅使用了无标注数据,但有标注的训练集并没有充分利用起来,针对这个问题,作者正在这样做:在蒸馏后的模型上用带标注数据进行fine-tune(仅更新prediction layer),但目前存在的问题是,并非在每个类上都是正向优化。

  • 半监督框架下的自蒸馏:我们正在想办法将自蒸馏融入到我们的半监督流程中,去迭代优化模型。设想的是,在流程最后对新模型再进行一次自蒸馏,但从目前的实验来看,不能保证每个类目都是正向优化。所以作者有以下想法:1)将自蒸馏后正向优化的部分作为teacher,再反向蒸馏到原teacher(只更新原teacher的prediction layer参数),看看原teacher会不会有提升,这个方法看上去可能会有效,但要蒸馏两次,时间成本太高;2)细想一下前面方案,其实可以一次蒸馏就搞定——在线蒸馏,在训练的过程中student也会影响teacher,从论文实验来看,在线蒸馏会使teacher和student都得到提升,所以作者对这个方法抱有很大希望

最新最全的文章请关注我的微信公众号或者知乎专栏:数据拾光者。

码字不易,欢迎小伙伴们点赞和分享。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值