DeBERTa (Decoding-enhanced BERT with disentangled attention)

51 篇文章 5 订阅
33 篇文章 4 订阅

1 简介

DeBERTa (Decoding-enhanced BERT with disentangled attention),如名字所述,有两个技术来改善BERT和RoBERTa,第一个是disentangled attention(解开的注意力),每个单词使用两个向量分别编码文本和位置,在单词们之间的注意力权重是通过使用它们的文本和相对位置的解开矩阵分别进行计算的。第2个技术是使用了一个加强的mask decoder,在decoding层引入绝对位置来预测masked tokens。

本文根据2021《DEBERTA: DECODING-ENHANCED BERT WITH DISENTANGLED ATTENTION》翻译总结。

DeBERTa目前(2021-02-08)在GLUE 基准上第一名。大版本的DeBERTa(15亿参数)在SuperGLUE 基准的平均分数上第一次超越人类。不过这并不意味着模型达到了人类知识水平,人类拥有更好的合成综合能力,利用学到的知识处理新的问题

RNN处理文本是按顺序的,而transformer采用self-attention可以并行处理输入文本的每个单词。所以对于大规模的模型训练,transformer更好。

Disentangled attention:比如deep和learning这两个词,当它们相连出现在一起的时候,其之间的依赖性会更强,而当它们出现在不同句子里,其之间依赖性就不强。所以相对位置的注意力很有用。

Enhanced mask decoder:Disentangled attention考虑的是相对位置,所以我们还需考虑绝对位置。

2 背景知识

2.1 transformer位置编码

以前的方法是添加一个位置偏置到每一个输入单词embedding,所以每一个单词用一个向量表示,其依赖于文本和位置。这个位置偏置可以采用绝对位置embedding或者相对位置embedding。目前发现相对位置对语言理解和生成任务是更加有效的。

2.2 MASKED LANGUAGE MODEL(MLM)

在这里插入图片描述

2.3 Self-attention

在这里插入图片描述

3 DeBERTa模型结构

3.1 DISENTANGLED ATTENTION

在这里插入图片描述

上面公式的4项目分别对应着content-to-content, content-to-position, position-to-content, and position-to-position。
我们发现最后一项位置对位置(position-to-position)的注意力没有太多用,故我们用了前面三个。

disentangled self-attention with relative position公式如下,形式类似于2.3节:
在这里插入图片描述

其中当用k表示最大相对距离时,token i和token j的相对距离公式δ(i,j)∈[ 0,2k ),定义如下:
在这里插入图片描述

最终算法如下,空间复杂度O(kd):
在这里插入图片描述

3.2 ENHANCED MASK DECODER考虑绝对单词位置

例如语句“a new store opened beside the new mall”,仅使用相对位置不是能有效区分‘store’和‘mall’,因为它们两和‘new’有相同的绝对位置。所以我们需引入绝对位置。

有两种方法引入绝对位置。BERT 是在输入层引入绝对位置。而在DeBERTa,我们是在transformer层之后,在softmax 层(masked token 预测)之前,引入绝对位置。如下图:
在这里插入图片描述

这样的话,DeBERTa在所有的transformer层捕捉相对位置,仅当编码masked 单词时,使用绝对位置作为补充。因此,我们叫DeBERTa的编码单元为enhanced mask decoder(EMD)。
DeBERTa使用绝对位置的方法比BERT的好,我们推断在BERT中较早的引入绝对位置可能伤害了模型,使其不能足够的学习相对位置。

3.3 尺度不变微调

引入了一个虚拟的对抗训练,Scale-invariant-Fine-Tuning (SiFT),进行模型微调。
虚拟的对抗训练是一个正则化的方法来改善模型的泛化能力。
在输入中增加干扰。我们采用SiFT算法,在normalized word embedding上添加干扰。

4 实验结果

在这里插入图片描述

4.1 ABLATION STUDY

• -EMD is the DeBERTa base model without EMD.
• -C2P is the DeBERTa base model without the content-to-position term (© in Eq. 4).
• -P2C is the DeBERTa base model without the position-to-content term ((b) in Eq. 4). As XLNet also uses the relative position bias, this model is close to XLNet plus EMD.

可以看到如果没有EMD、C2P、P2C,模型的效果都会降低。
在这里插入图片描述

4.2 15亿(1.5Billion)参数的模型

DeBERTa_1.5B: 48层,hidden 大小等于1536,24个注意力头。训练数据有160G.
不过T5有11 billion参数,DeBERTa_1.5B参数量还是很小的,效果也更好。如下表:

在这里插入图片描述

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Sure, here's an example of how to use the libavcodec API to decode audio: ```c // include necessary headers #include <stdio.h> #include <stdlib.h> #include <stdint.h> #include <math.h> #include <libavutil/frame.h> #include <libavcodec/avcodec.h> // define decoder context and packet AVCodecContext *dec_ctx = NULL; AVPacket *pkt = NULL; // define output frame AVFrame *decoded_frame = NULL; // define buffer for audio samples int16_t *audio_buf = NULL; int audio_buf_size = 0; // define audio parameters int audio_stream_index = -1; int audio_sample_rate = 0; int audio_channels = 0; int audio_frame_size = 0; // function to initialize the decoder int init_decoder(const char *filename) { int ret; // initialize the codec and format context AVFormatContext *fmt_ctx = NULL; if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) { fprintf(stderr, "Error opening input file: %s\n", av_err2str(ret)); return ret; } if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) { fprintf(stderr, "Error finding stream information: %s\n", av_err2str(ret)); avformat_close_input(&fmt_ctx); return ret; } // find the audio stream for (int i = 0; i < fmt_ctx->nb_streams; i++) { if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { audio_stream_index = i; break; } } if (audio_stream_index == -1) { fprintf(stderr, "Error finding audio stream\n"); avformat_close_input(&fmt_ctx); return AVERROR_STREAM_NOT_FOUND; } // get the audio parameters dec_ctx = avcodec_alloc_context3(NULL); if (!dec_ctx) { fprintf(stderr, "Error allocating codec context\n"); avformat_close_input(&fmt_ctx); return AVERROR(ENOMEM); } avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar); av_codec_set_pkt_timebase(dec_ctx, fmt_ctx->streams[audio_stream_index]->time_base); AVCodec *codec = avcodec_find_decoder(dec_ctx->codec_id); if (!codec) { fprintf(stderr, "Error finding codec\n"); avcodec_free_context(&dec_ctx); avformat_close_input(&fmt_ctx); return AVERROR_DECODER_NOT_FOUND; } if ((ret = avcodec_open2(dec_ctx, codec, NULL)) < 0) { fprintf(stderr, "Error opening codec: %s\n", av_err2str(ret)); avcodec_free_context(&dec_ctx); avformat_close_input(&fmt_ctx); return ret; } audio_sample_rate = dec_ctx->sample_rate; audio_channels = dec_ctx->channels; audio_frame_size = dec_ctx->frame_size; // allocate the packet and output frame pkt = av_packet_alloc(); decoded_frame = av_frame_alloc(); if (!pkt || !decoded_frame) { fprintf(stderr, "Error allocating packet or frame\n"); avcodec_free_context(&dec_ctx); av_packet_free(&pkt); av_frame_free(&decoded_frame); avformat_close_input(&fmt_ctx); return AVERROR(ENOMEM); } // close the format context avformat_close_input(&fmt_ctx); return 0; } // function to decode the next audio packet int decode_packet() { int ret; // read the next packet if ((ret = av_read_frame(fmt_ctx, pkt)) < 0) { if (ret == AVERROR_EOF) { return 0; } else { fprintf(stderr, "Error reading packet: %s\n", av_err2str(ret)); return ret; } } // check if the packet is an audio packet if (pkt->stream_index != audio_stream_index) { av_packet_unref(pkt); return 1; } // send the packet to the decoder if ((ret = avcodec_send_packet(dec_ctx, pkt)) < 0) { fprintf(stderr, "Error sending packet to decoder: %s\n", av_err2str(ret)); av_packet_unref(pkt); return ret; } av_packet_unref(pkt); // read the decoded frames while (ret >= 0) { ret = avcodec_receive_frame(dec_ctx, decoded_frame); if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { break; } else if (ret < 0) { fprintf(stderr, "Error receiving frame from decoder: %s\n", av_err2str(ret)); return ret; } // allocate space for the audio buffer int buf_size = av_samples_get_buffer_size(NULL, audio_channels, decoded_frame->nb_samples, AV_SAMPLE_FMT_S16, 1); if (buf_size <= 0) { fprintf(stderr, "Error getting buffer size\n"); return AVERROR_INVALIDDATA; } if (!audio_buf) { audio_buf = (int16_t *) malloc(buf_size); audio_buf_size = buf_size; } else if (buf_size > audio_buf_size) { audio_buf = (int16_t *) realloc(audio_buf, buf_size); audio_buf_size = buf_size; } if (!audio_buf) { fprintf(stderr, "Error allocating audio buffer\n"); return AVERROR(ENOMEM); } // convert the audio samples to signed 16-bit integers int16_t *samples = (int16_t *) decoded_frame->data[0]; for (int i = 0; i < decoded_frame->nb_samples * audio_channels; i++) { audio_buf[i] = samples[i]; } } return 1; } // function to clean up resources void cleanup() { if (dec_ctx) { avcodec_free_context(&dec_ctx); } if (pkt) { av_packet_free(&pkt); } if (decoded_frame) { av_frame_free(&decoded_frame); } if (audio_buf) { free(audio_buf); } } // main function int main(int argc, char *argv[]) { if (argc != 2) { fprintf(stderr, "Usage: %s <input file>\n", argv[0]); return 1; } if (init_decoder(argv[1]) < 0) { cleanup(); return 1; } int ret; while ((ret = decode_packet()) > 0) { // do something with the audio samples } if (ret < 0) { cleanup(); return 1; } cleanup(); return 0; } ``` This example initializes the decoder, reads packets from the input file, sends them to the decoder, and converts the decoded audio samples to signed 16-bit integers. You can modify the `do something with the audio samples` section to perform whatever processing you need on the audio data.

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值