Understanding FastBERT's Sample-wise Adaptive Inference Mechanism

Notes on "FastBERT: a Self-distilling BERT with Adaptive Inference Time"




Preface

FastBERT paper: https://arxiv.org/pdf/2004.02178.pdf
GitHub open-source implementation: https://github.com/BitVoyage/FastBERT

To reduce BERT's inference time, the authors propose FastBERT, a model with self-distillation and a sample-wise adaptive inference mechanism. The paper demonstrates the idea on text-classification tasks; the overall pipeline has three stages:

(1) Fine-tuning: a classifier head (the Teacher Classifier) is attached after the BERT model and fine-tuned for classification. The classifier structure is: fully-connected layer (768->128) -> self-attention (128->128) -> fully-connected layer (128->128) -> fully-connected layer (128->num_class);
(2) Self-distillation: a Student Classifier is attached after every layer of the BERT encoder. The parameters learned during fine-tuning are frozen, and each Student Classifier learns to match the Teacher Classifier's output distribution, with the KL divergence as the loss;
(3) Inference: the Student Classifier at each layer predicts on the sample, and the prediction's uncertainty is measured with a normalized entropy; the larger the entropy, the higher the uncertainty. The speed argument is the uncertainty threshold for early exit and is positively correlated with inference speed: the larger the speed threshold, the faster the inference. Easy samples may exit after only a few layers, while in the worst case a sample passes through all layers, so the amount of computation at inference time is reduced and inference becomes faster.

I. The FastBERT Model

FastBERT, a pre-trained model with a sample-wise adaptive mechanism. It can adjust the number of executed layers dynamically to reduce computational steps. This model also has a unique self-distillation process that requires minimal changes to the structure, achieving faster yet as accurate outcomes within a single framework.
The authors propose a pre-trained model with a sample-wise adaptive mechanism that dynamically adjusts the number of executed layers to reduce computation. The self-distillation process is carried out within a single framework with only minimal changes to the model structure, giving faster inference with comparable accuracy.

Figure: the FastBERT model architecture

Backbone: the embedding layer, the encoder consisting of stacked Transformer blocks, and the teacher classifier (i.e., the BERT model plus the Teacher Classifier head).
Branch: all of the Student Classifiers.

Structure and code implementation of the Classifier layer:

# excerpt from the repo's model code; BERTSelfAttention (and BERTLayer further below) are defined elsewhere in the repo
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class FastBERTClassifier(nn.Module):
    def __init__(self, config, op_config):
        super(FastBERTClassifier, self).__init__()

        cls_hidden_size = op_config["cls_hidden_size"]              # e.g. 128
        num_attention_heads = op_config['cls_num_attention_heads']
        num_class = op_config["num_class"]

        # narrow the BERT hidden size (e.g. 768) down to cls_hidden_size
        self.dense_narrow = nn.Linear(config.hidden_size, cls_hidden_size)
        # lightweight self-attention over the narrowed representation
        self.selfAttention = BERTSelfAttention(config, hidden_size=cls_hidden_size, num_attention_heads=num_attention_heads)
        self.dense_prelogits = nn.Linear(cls_hidden_size, cls_hidden_size)
        self.dense_logits = nn.Linear(cls_hidden_size, num_class)

    def forward(self, hidden_states):
        states_output = self.dense_narrow(hidden_states)
        states_output = self.selfAttention(states_output, None, use_attention_mask=False)
        # take the representation of the [CLS] token (position 0)
        token_cls_output = states_output[:, 0]
        prelogits = self.dense_prelogits(token_cls_output)
        logits = self.dense_logits(prelogits)
        return logits
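As a quick sanity check of the shape flow, here is a self-contained sketch that mimics this head with stock PyTorch modules; nn.MultiheadAttention merely stands in for the repo's BERTSelfAttention, and the sizes (768, 128, num_class=2) are illustrative values, not read from the actual config:

import torch
import torch.nn as nn

# simplified stand-in for FastBERTClassifier, for shape illustration only
hidden_size, cls_hidden_size, num_heads, num_class = 768, 128, 2, 2

dense_narrow = nn.Linear(hidden_size, cls_hidden_size)             # 768 -> 128
self_attention = nn.MultiheadAttention(cls_hidden_size, num_heads, batch_first=True)
dense_prelogits = nn.Linear(cls_hidden_size, cls_hidden_size)      # 128 -> 128
dense_logits = nn.Linear(cls_hidden_size, num_class)               # 128 -> num_class

hidden_states = torch.randn(4, 32, hidden_size)    # (batch, seq_len, hidden_size)
x = dense_narrow(hidden_states)
x, _ = self_attention(x, x, x)                     # (batch, seq_len, 128)
cls_token = x[:, 0]                                # take the [CLS] position
logits = dense_logits(dense_prelogits(cls_token))
print(logits.shape)                                # torch.Size([4, 2])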

Code implementation of FastBERTGraph, the core logic for the fine-tuning, distillation, and inference stages:

class FastBERTGraph(nn.Module):
    def __init__(self, bert_config, op_config):
        super(FastBERTGraph, self).__init__()
        bert_layer = BERTLayer(bert_config)
        self.layers = nn.ModuleList([copy.deepcopy(bert_layer) for _ in range(bert_config.num_hidden_layers)])    

        self.layer_classifier = FastBERTClassifier(bert_config, op_config)
        # dict holding a Student Classifier per encoder layer plus the final Teacher Classifier
        self.layer_classifiers = nn.ModuleDict()  
        for i in range(bert_config.num_hidden_layers - 1):
            self.layer_classifiers['branch_classifier_'+str(i)] = copy.deepcopy(self.layer_classifier)
        self.layer_classifiers['final_classifier'] = copy.deepcopy(self.layer_classifier)

        self.ce_loss_fct = nn.CrossEntropyLoss()  # cross-entropy loss
        self.num_class = torch.tensor(op_config["num_class"], dtype=torch.float32)
        
    def forward(self, hidden_states, attention_mask, labels=None, inference=False, inference_speed=0.5, training_stage=0):
        #-----Inference stage: exit early from layer i when the student's uncertainty is low-----#
        if inference:
            uncertain_infos = [] 
            for i, (layer_module, (k, layer_classifier_module)) in enumerate(zip(self.layers, self.layer_classifiers.items())):
                hidden_states = layer_module(hidden_states, attention_mask)
                logits = layer_classifier_module(hidden_states)
                prob = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
                uncertain_infos.append([uncertain, prob])

                # early exit with the current prediction
                if uncertain < inference_speed:
                    return prob, i, uncertain_infos
            return prob, i, uncertain_infos
        #------Training: stage 0 is the initial fine-tuning, stage 1 is distillation--------#
        else:
            # initial fine-tuning, same as ordinary training
            if training_stage == 0:
                for layer_module in self.layers:
                    hidden_states = layer_module(hidden_states, attention_mask)
                logits = self.layer_classifier(hidden_states)
                loss = self.ce_loss_fct(logits, labels)
                return loss, logits
            # distillation: KL divergence between each layer's student and the teacher as the loss
            else:
                all_encoder_layers = []
                for layer_module in self.layers:
                    hidden_states = layer_module(hidden_states, attention_mask)
                    all_encoder_layers.append(hidden_states)

                all_logits = []
                for encoder_layer, (k, layer_classifier_module) in zip(all_encoder_layers, self.layer_classifiers.items()):
                    layer_logits = layer_classifier_module(encoder_layer)
                    all_logits.append(layer_logits)
                    
                #NOTE:debug if freezed
                #print(self.layer_classifiers['final_classifier'].dense_narrow.weight)

                loss = 0.0
                teacher_log_prob = F.log_softmax(all_logits[-1], dim=-1)
                for student_logits in all_logits[:-1]:
                    student_prob = F.softmax(student_logits, dim=-1)
                    student_log_prob = F.log_softmax(student_logits, dim=-1)
                    uncertain = torch.sum(student_prob * student_log_prob, 1) / (-torch.log(self.num_class))
                    #print('uncertain:', uncertain[0])

                    D_kl = torch.sum(student_prob * (student_log_prob - teacher_log_prob), 1)
                    D_kl = torch.mean(D_kl)
                    loss += D_kl 
                return loss, all_logits

1. Initial training stage

for layer_module in self.layers:
    hidden_states = layer_module(hidden_states, attention_mask)
logits = self.layer_classifier(hidden_states)
loss = self.ce_loss_fct(logits, labels)
return loss, logits

In this stage a classifier head is attached after BERT and the model is fine-tuned for classification, exactly like ordinary BERT fine-tuning.
2. Self-distillation stage

all_encoder_layers = []  # hidden states of every encoder layer
for layer_module in self.layers:
    hidden_states = layer_module(hidden_states, attention_mask)
    all_encoder_layers.append(hidden_states)

all_logits = []  # logits of each Classifier layer
for encoder_layer, (k, layer_classifier_module) in zip(all_encoder_layers, self.layer_classifiers.items()):
    layer_logits = layer_classifier_module(encoder_layer)
    all_logits.append(layer_logits)
#NOTE:debug if freezed
#print(self.layer_classifiers['final_classifier'].dense_narrow.weight)
loss = 0.0
# log-probabilities of the Teacher Classifier (the last classifier)
teacher_log_prob = F.log_softmax(all_logits[-1], dim=-1)
for student_logits in all_logits[:-1]:
    student_prob = F.softmax(student_logits, dim=-1)
    student_log_prob = F.log_softmax(student_logits, dim=-1)
    # uncertainty (normalized entropy) of the student prediction
    uncertain = torch.sum(student_prob * student_log_prob, 1) / (-torch.log(self.num_class))
    #print('uncertain:', uncertain[0])
    # KL divergence between the Student Classifier and Teacher Classifier distributions
    D_kl = torch.sum(student_prob * (student_log_prob - teacher_log_prob), 1)
    D_kl = torch.mean(D_kl)
    loss += D_kl 
return loss, all_logits

Freeze the backbone parameters and train only the Student Classifiers to fit the teacher's output distribution:

model = FastBertModel(bert_config, config)
load_saved_model(model, args.save_model_path)
save_model_path_for_train = args.save_model_path_distill

#Freeze Part Model: the parameters learned in the fine-tuning stage
for name, p in model.named_parameters():
    if "branch_classifier" not in name:
        p.requires_grad = False

The KL divergence used to measure how far a Student Classifier's distribution $p_s$ is from the Teacher Classifier's distribution $p_t$:

$D_{KL}(p_s \,\|\, p_t) = \sum_{i=1}^{N} p_s(i)\left(\log p_s(i) - \log p_t(i)\right)$
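To connect the formula with the loop above, here is a small self-contained sketch (the logits below are random, for illustration only) that computes the per-sample KL term exactly as the distillation code does, and checks it against PyTorch's built-in F.kl_div:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_class = 4
student_logits = torch.randn(8, num_class)   # a batch of 8 made-up student outputs
teacher_logits = torch.randn(8, num_class)   # corresponding teacher outputs

student_prob = F.softmax(student_logits, dim=-1)
student_log_prob = F.log_softmax(student_logits, dim=-1)
teacher_log_prob = F.log_softmax(teacher_logits, dim=-1)

# per-sample KL(student || teacher), summed over classes and averaged over the batch,
# exactly as in the distillation loop above
D_kl = torch.mean(torch.sum(student_prob * (student_log_prob - teacher_log_prob), dim=1))

# the same quantity via the built-in kl_div (input = log-probabilities of the second
# argument of the KL, target = probabilities of the first)
D_kl_builtin = F.kl_div(teacher_log_prob, student_prob, reduction='batchmean')

print(D_kl.item(), D_kl_builtin.item())  # the two numbers agree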
3. Inference stage code

uncertain_infos = [] 
for i, (layer_module, (k, layer_classifier_module)) in enumerate(zip(self.layers, self.layer_classifiers.items())):
    hidden_states = layer_module(hidden_states, attention_mask)
    logits = layer_classifier_module(hidden_states)
    # output probabilities of the Classifier at this layer
    prob = F.softmax(logits, dim=-1)
    log_prob = F.log_softmax(logits, dim=-1)  # log-probabilities
    # uncertainty (normalized entropy)
    uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
    uncertain_infos.append([uncertain, prob])

    # early exit if the uncertainty falls below the speed threshold
    if uncertain < inference_speed:
        return prob, i, uncertain_infos
return prob, i, uncertain_infos

The uncertainty formula (a normalized entropy in $[0, 1]$):

$\text{Uncertainty} = \dfrac{\sum_{i=1}^{N} p_s(i)\,\log p_s(i)}{\log \frac{1}{N}}$
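A quick self-contained illustration with made-up distributions: a peaked prediction gives an uncertainty close to 0 and triggers an early exit for any reasonable speed threshold, while a near-uniform prediction gives an uncertainty close to 1 and is passed on to deeper layers:

import torch

def uncertainty(prob, num_class):
    # normalized entropy in [0, 1], matching the expression used in the model code
    return torch.sum(prob * torch.log(prob), dim=-1) / (-torch.log(torch.tensor(float(num_class))))

confident = torch.tensor([[0.97, 0.01, 0.01, 0.01]])  # peaked distribution
unsure = torch.tensor([[0.28, 0.26, 0.24, 0.22]])     # near-uniform distribution

print(uncertainty(confident, 4))  # about 0.12 -> exits early when speed > 0.12
print(uncertainty(unsure, 4))     # about 1.00 -> continues to the next layer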

II. Training FastBERT

The code implementation follows the GitHub repo: https://github.com/BitVoyage/FastBERT
Training logic:

1. Load the pre-trained BERT and fine-tune it for classification.
2. Freeze the backbone and the final-layer teacher classifier; the sub-model (student classifier) at every layer is trained to fit the teacher classifier (KL divergence as the loss).
3. At inference time, depending on the input sample, a sub-classifier returns early when its confidence is high.

Training-process code logic:

    if args.run_mode == 'train':
        # initial training: the fine-tuning stage
        if args.train_stage == 0:
            model = FastBertModel.load_pretrained_bert_model(bert_config, config,
                        pretrained_model_path=config.get("bert_pretrained_model_path"))
            save_model_path_for_train = args.save_model_path
        # distillation training
        elif args.train_stage == 1:
            model = FastBertModel(bert_config, config)
            load_saved_model(model, args.save_model_path)
            save_model_path_for_train = args.save_model_path_distill

            #Freeze Part Model: the backbone parameters from the fine-tuning stage
            for name, p in model.named_parameters():
                if "branch_classifier" not in name:
                    p.requires_grad = False
            logging.info("Main Graph and Teacher Classifier Freezed, Student Classifier will Distilling")
        else:
            raise RuntimeError('Operation Train Stage(0 or 1) not Legal')

    elif args.run_mode == 'eval':
        model = FastBertModel(bert_config, config)
        load_saved_model(model, args.save_model_path)
    else:
        raise RuntimeError('Operation Mode not Legal')    

In the formulas above, $p_s$ denotes a classifier's output probability distribution and $N$ is the number of class labels.
Testing procedure:

    1. Initial training:
    sh run_scripts/script_train_stage0.sh

    2. Distillation training:
    sh run_scripts/script_train_stage1.sh
    **Note**: the input data for the distillation stage is unlabeled; more data can be added as needed to improve robustness.

    3. Inference:
    sh run_scripts/script_infer.sh
    The inference_speed argument (0.0~1.0) controls the degree of acceleration; see the sketch after this list.

    4. Deployment:
    python3 predict.py
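As a rough, hypothetical sketch of how the speed threshold might be swept during evaluation, the loop below times the early-exit forward pass at several inference_speed values. It assumes a loaded FastBertModel and a tokenized eval_dataloader exist, that batches have size 1 (the early exit compares a per-sample uncertainty against the threshold), and that the model's forward accepts inference=True and inference_speed the way FastBERTGraph.forward does; the actual argument names in the repo's predict.py and infer scripts may differ.

import time
import torch

# hypothetical sweep; `model` and `eval_dataloader` are assumed to be set up as in the repo's eval path
model.eval()
for speed in [0.0, 0.1, 0.5, 0.8]:
    n_correct, n_total, elapsed = 0, 0, 0.0
    with torch.no_grad():
        for tokens, masks, labels in eval_dataloader:   # batch size 1 assumed
            start = time.time()
            # assumed forward signature, mirroring FastBERTGraph.forward
            probs, exit_layer, _ = model(tokens, masks, inference=True, inference_speed=speed)
            elapsed += time.time() - start
            n_correct += (probs.argmax(dim=-1) == labels).sum().item()
            n_total += labels.size(0)
    print("speed=%.1f  time_per_record=%.4f  acc=%.4f" % (speed, elapsed / n_total, n_correct / n_total))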

Results from running the source code myself:
Test device: a single GPU.

Inference on the test set
ChnSentiCorp:
speed_arg:0.0, time_per_record:0.0364, acc:0.9533,   baseline
speed_arg:0.1, time_per_record:0.0330, acc:0.9533,   1.10x
speed_arg:0.5, time_per_record:0.0223, acc:0.9383,   1.63x
speed_arg:0.8, time_per_record:0.0171, acc:0.9108,   2.13x
Inference on the dev set
ChnSentiCorp:
speed_arg:0.0, time_per_record:0.0365, acc:0.9392,   baseline
speed_arg:0.1, time_per_record:0.0332, acc:0.9400,   1.10x
speed_arg:0.5, time_per_record:0.0237, acc:0.9333,   1.54x
speed_arg:0.8, time_per_record:0.0176, acc:0.9100,   2.07x

Test device: a single GPU; results from a model retrained following the same steps.

Inference on the dev set
ChnSentiCorp:
speed_arg:0.0, time_per_record:0.0356, acc:0.9425,   baseline
speed_arg:0.1, time_per_record:0.0263, acc:0.9433,   1.35x
speed_arg:0.5, time_per_record:0.0186, acc:0.9383,   1.91x
speed_arg:0.8, time_per_record:0.0157, acc:0.9050,   2.26x

The accuracy values are close to those reported by the GitHub author, but the inference-speed numbers differ quite a bit, which is probably device-related.
Test results provided on GitHub:

ChnSentiCorp:
speed_arg:0.0, time_per_record:0.14725032741416672, acc:0.9400,   baseline
speed_arg:0.1, time_per_record:0.10302954971909761, acc:0.9420,   1.42x
speed_arg:0.5, time_per_record:0.03420266199111938, acc:0.9340,   4.29x
speed_arg:0.8, time_per_record:0.019530397139952513, acc:0.9160,  7.54x
Note: at speed=0.1 the accuracy is slightly higher than the baseline; this is possible, likely a regularization-like effect.

From the author's measured results: after distillation, at speed 0.1 inference is 1.42x faster and accuracy even improves slightly; at speed 0.5 inference is 4.29x faster with an accuracy drop of 0.006; as speed increases further, inference gets faster and accuracy drops.

III. Test Results in the FastBERT Paper

Figure: benchmark results table from the FastBERT paper
The paper reports tests on six Chinese and six English corpora. With accuracy essentially unchanged (speed=0.1), FastBERT gives a 2-5x inference speedup; with a small accuracy loss, FastBERT is 7 to 11 times faster than BERT.

My own tests on several datasets, following the GitHub code without modifying the parameters, on a single GTX 1080 Ti GPU (the last column of each listing is the speedup relative to the speed=0 baseline):
ChnSentiCorp dataset

speed_arg 0.0  time_per_record 0.0286  acc 0.9425  baseline
speed_arg 0.1  time_per_record 0.0286  acc 0.9433  1.0x
speed_arg 0.5  time_per_record 0.0178  acc 0.9383  1.6x
speed_arg 0.8  time_per_record 0.0136  acc 0.9050  2.1x
speed_arg 1    time_per_record 0.0049  acc 0.5025  5.8x

thnews dataset

speed 0    time_per_record 0.0334  acc 0.9704  baseline
speed 0.1  time_per_record 0.0303  acc 0.9704  1.1x
speed 0.5  time_per_record 0.0268  acc 0.9704  1.2x
speed 0.8  time_per_record 0.0126  acc 0.9696  2.7x
speed 1    time_per_record 0.0070  acc 0.8930  4.8x

Weibo dataset

speed 0    time_per_record 0.0208  acc 0.9788  baseline
speed 0.1  time_per_record 0.0280  acc 0.9788  0.7x
speed 0.5  time_per_record 0.0065  acc 0.9783  3.2x
speed 0.8  time_per_record 0.0058  acc 0.9774  3.6x
speed 1    time_per_record 0.0045  acc 0.9680  6.2x

Douban dataset

speed 0    time_per_record 0.0283  acc 0.8796  baseline
speed 0.1  time_per_record 0.0240  acc 0.8808  1.2x
speed 0.5  time_per_record 0.0144  acc 0.8693  2.0x
speed 1    time_per_record 0.0045  acc 0.7100  6.3x

Online-shopping dataset

speed 0    time_per_record 0.0299  acc 0.9697  baseline
speed 0.1  time_per_record 0.0157  acc 0.9694  1.9x
speed 0.5  time_per_record 0.0071  acc 0.9602  4.2x
speed 1    time_per_record 0.0045  acc 0.9202  6.6x

These results show that the post-distillation speedup varies considerably across the different Chinese datasets, and repeated runs on the same dataset also show some fluctuation in inference speed, which is probably related to the state of the device at the time.
The test datasets can be downloaded from https://share.weiyun.com/ZctQJP8h or from Baidu Netdisk: https://pan.baidu.com/s/1MayPlDKJ0_UEZgrdp_1MLg
Extraction code: x6fw

Summary

Tests of FastBERT on multiple datasets show that, with only a small loss in accuracy, inference speed improves to varying degrees.

References

Zhihu article: https://zhuanlan.zhihu.com/p/127869267
GitHub: https://github.com/autoliuweijie/FastBERT
