DistilBERT 论文笔记

DistilBERT是BERT的一个轻量化版本,通过模型蒸馏技术,旨在保持高性能的同时减小模型大小和加快推理速度。文章详细介绍了蒸馏过程、DistilBERT的结构、初始化方法以及损失函数的设计,包括MLM、CE和Cos损失。实验结果显示,DistilBERT在精度上仅轻微下降,但在速度和参数量上有显著提升。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

单位:HuggingFace
时间:2020.5
发表:NIPS2019
论文链接:https://arxiv.org/pdf/1910.01108.pdf

一、背景

1. 什么是distill(蒸馏)?

蒸馏简单的说是将大模型(teacher)的学习结果,作为小模型(student)的学习目标,意在小模型能学习到大模型的表示。

蒸馏这个方法的核心思想是:好模型的目标不是拟合训练数据而是学习如何泛化到新的数据

所以蒸馏的目标是让学生模型学习到教师模型的泛化能力,理论上得到的结果会比单纯拟合训练数据的学生模型要好。

2. BERT有哪些短板?

从应用落地的角度来说,bert虽然效果好,但有一个短板就是预训练模型太大,预测时间在平均在300ms以上(一条数据),无法满足线上并发量要求高的业务需求。

二、DistilBERT, a distilled version of BERT

1. 作者的思路

之前的模型蒸馏本质上都是两个loss,即distillation loss和student loss

这样模型学到的都是精调后的知识,即模型都是任务相关的,作者想蒸馏出一个任务无关的BERT,这样通用性更强,在具体任务时做具体的精调即可。

2. 具体做法

I. 模型结构

教师模型采用预训练好的BERT-base,学生模型则是6层的transformer。

II. 学生模型初始化方法

采用了BERT-PKD提出的PKD-skip的方式进行初始化,即用BERT-base的第[2,4,6,8,10]层的参数作为学生模型的参数。

III. Loss的设计

损失函数最终有三个,具体为:

  • MLM loss

### DistilBERT 模型介绍 DistilBERT 是 BERT 的压缩版本,通过知识蒸馏技术创建而成[^1]。该模型保留了原始 BERT 大约 97% 的性能,而参数量却减少了 40%,推理速度提高了 60%。 #### 主要特点 - **更高效**:相比完整的 BERT 模型,DistilBERT 需要较少计算资源,在实际应用中表现出更高的效率。 - **保持高精度**:尽管进行了大幅简化,但在多种自然语言处理任务上仍能维持较高的准确性。 - **易于部署**:由于体积较小且运行速度快,更适合于移动设备或其他受限环境下的部署。 ### 使用方法示例 为了展示如何使用 DistilBERT 进行问答任务,下面提供了一个简单的 Python 实现: ```python from transformers import AutoTokenizer, AutoModelForQuestionAnswering import torch # 加载预训练的 DistilBERT QA 模型和分词器 tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased-distilled-squad') model = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad') context = "Transformers are a class of deep learning models that achieve state-of-the-art results on many NLP tasks." question = "What do Transformers achieve?" encodings = tokenizer(context, question, truncation=True, padding=True) input_ids = torch.tensor(encodings['input_ids']).unsqueeze(0) attention_mask = torch.tensor(encodings['attention_mask']).unsqueeze(0) outputs = model(input_ids=input_ids, attention_mask=attention_mask)[^3] start_scores = outputs.start_logits end_scores = outputs.end_logits answer_start = torch.argmax(start_scores) answer_end = torch.argmax(end_scores) + 1 all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) print(tokenizer.convert_tokens_to_string(all_tokens[answer_start:answer_end])) ``` 这段代码展示了如何加载预训练好的 DistilBERT 模型并利用其完成特定上下文中的问题回答过程。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值