知识蒸馏在文本方向上的应用

知识蒸馏在文本方向上的应用

​  完整项目代码在我的GitHub仓库
​  虽然说做文本不像图像对gpu依赖这么高,但是当需要训练一个大模型或者拿这个模型做预测的时候,也是耗费相当多资源的,尤其是BERT出来以后,不管做什么用BERT效果都能提高,万物皆可BERT。

​  然而想要在线上部署应用,大公司倒还可以烧钱玩,毕竟有钱任性,小公司可玩不起,成本可能都远大于效益。这时候,模型压缩的重要性就体现出来了,如果一个小模型能够替代大模型,而这个小模型的效果又和大模型差不多,何乐而不为。

知识蒸馏介绍

在讲知识蒸馏时一定会提到的Geoffrey Hinton开山之作Distilling the Knowledge in a Neural Network当然也是在图像中开的山,下面简单做一个介绍。

​  知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:

1.原始模型训练: 训练"Teacher模型", 它的特点是模型相对复杂,可以由多个分别训练的模型集成而成。

2.精简模型训练: 训练"Student模型", 它是参数量较小、模型结构相对简单的单模型。

模型结构

蒸馏结构图

借用YJango大佬的图,这里我简单解释一下我们怎么构建这个模型
1.训练大模型
  首先我们先对大模型进行训练,得到训练参数保存,这一步在上图中并未体现,上图最左部分是使用第一步训练大模型得到的参数。
2. 计算大模型输出
  训练完大模型之后,我们将计算soft target,不直接计算output的softmax,这一步进行了一个divided by T蒸馏操作。(注:这时候的输入数据可以与训练大模型时的输入不一致,但需要保证与训练小模型时的输入一致)
3. 训练小模型
  小模型的训练包含两部分。
  -soft target loss
  -hard target loss
  通过调节λ的大小来调整两部分损失函数的权重。
5. 小模型预测
​  预测就没什么不同了,按常规方式进行预测。

模型实现

模型基本上是对论文Distilling Task-Specific Knowledge from BERT into Simple Neural Networks的复现,下面介绍部分代码实现

代码结构

Teacher模型:BERT模型

Student模型:一层的biLSTM

LOSS函数:交叉熵 、MSE LOSS

知识函数:用最后一层的softmax前的logits作为知识表示

学生模型输入

​  Student模型的输入句向量由句中每一个词向量求和取平均得到,词向量为预训练好的300维中文向量,训练数据集为Wikipedia_zh中文维基百科。

w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.wiki.word')
# 生成句向量
def build_sentence_vector(sentence,w2v_model):

    sen_vec = [0]*300
    count = 0
    for word in sentence:
        try:
            sen_vec += w2v_model[word]
            count += 1
        except KeyError:
            continue
    if count != 0:
        sen_vec /= count
    return sen_vec

学生模型结构

​  学生模型为单层biLSTM,再接一层全连接。

class biLSTM(nn.Module):
    def __init__(self):
        super(biLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=300, hidden_size=256,
                         num_layers=1, batch_first=True, dropout=0, bidirectional= True)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x, hidden=None):
        lstm_out, hidden = self.lstm(x, hidden)     
        out = self.fc1(lstm_out)
        activated_t = F.relu(out)
        linear_out = self.fc2(activated_t)

        return linear_out, hidden

教师模型结构

​  教师模型为BERT,并对最后四层进行微调,后面也接了一层全连接。

class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in list(self.bert.parameters())[:-4]:
            param.requires_grad = False
        self.fc = nn.Linear(config.hidden_size, 192)
        # self.fc1 = nn.Linear(192, 48)
        self.fc2 = nn.Linear(192, config.num_classes)

    def forward(self, x):
        context = x[0]  # 输入的句子
        mask = x[2]  # 对padding部分进行mask
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers= False)
        out = self.fc(pooled)
        out = F.relu(out)
        # out = self.fc1(out)
        out = self.fc2(out)
        return out

损失函数

​  损失函数为学生输出s_logits和教师输出t_logits的MSE损失与学生输出与真实标签的交叉熵。

# 损失函数
def get_loss(t_logits, s_logits, label, a, T):
    loss1 = nn.CrossEntropyLoss()
    loss2 = nn.MSELoss()
    loss = a * loss1(s_logits, label) + (1 - a) * loss2(t_logits, s_logits)
    return loss

模型效果

Teacher

Running time: 116.05915258956909 s

precisionrecallF1-scoresupport
00.910.840.872168
10.820.900.861833
accuracy0.864001
macro avg0.860.870.864001
weight avg0.870.860.864001

Student

Running time: 0.155623197555542 s

precisionrecallF1-scoresupport
00.870.850.862168
10.830.850.841833
accuracy0.854001
macro avg0.850.850.854001
weight avg0.850.850.854001

​  可以看出student模型与teacher模型相比精度有一定的丢失,这也可以理解,毕竟student模型结构简单。而在运行时间上大模型是小模型的746倍(cpu)。

TNEWS测试效果

在数据集中选了5类并做了下采样。(此部分具体说明后续完善)

Student alone

precisionrecallF1-scoresupport
story0.64890.79070.7128215
sports0.76690.78490.7758767
house0.73500.77780.7558378
car0.81620.75220.7829791
game0.73190.70410.7177659
accuracy0.75622810
macro avg0.73980.76190.74902810
weight avg0.75920.75620.75672810

Teacher

precisionrecallF1-scoresupport
story0.61590.86510.7195215
sports0.84230.79400.8174767
house0.80300.85190.8267378
car0.88230.78630.8316791
game0.78350.80730.7952659
accuracy0.80822810
macro avg0.78540.82090.79812810
weight avg0.81720.80820.81002810

Student

precisionrecallF1-scoresupport
story0.52070.81860.6365215
sports0.84110.70400.7665767
house0.76780.76980.7688378
car0.81040.74590.7768791
game0.68050.74660.7120659
accuracy0.74342810
macro avg0.72410.75700.73212810
weight avg0.76040.74340.74702810

已知问题

  1. 没有写蒸馏过程,就是divided by T是如何实现蒸馏(其实是懒)
  2. 直接用student小模型训练数据的效果如何,并未做测试。
    在TNEWS数据集上完成测试,并上传了训练代码。
  3. 数据集-量并不是很大,自己也只标注了几千条数据,后续会在CLUE的TNEWS短文本分类数据集上做测试,再出一个对比结果。
    在TNEWS数据集上测试,蒸馏结果与直接用student训练效果并未明显提高,还需后续更多测试。

参考链接

  1. 如何理解soft target这一做法? 知乎 YJango的回答

  2. 【经典简读】知识蒸馏(Knowledge Distillation) 经典之作

  3. Distilling the Knowledge in a Neural Network

  4. Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

  5. Chinese-Word-Vectors

  • 9
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 31
    评论
评论 31
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值