TextBrewer: a PyTorch-based toolkit designed for knowledge distillation in NLP

banner.png


TextBrewer is a PyTorch-based model distillation toolkit for natural language processing. It includes various distillation techniques from both the NLP and CV fields and provides an easy-to-use distillation framework, which allows users to quickly experiment with state-of-the-art distillation methods to compress models with a relatively small sacrifice in performance, increasing inference speed and reducing memory usage.

Update

Apr 26, 2020

Added Chinese NER task (MSRA NER) results.

Added results for distilling to the T12-nano model, which has a structure similar to Electra-small.

Updated some results of CoNLL-2003, CMRC 2018 and DRCD.

Apr 22, 2020

Updated to 0.1.9 (added cache option which speeds up distillation; fixed some bugs). See details in releases.

Added experimental results for distilling Electra-base to Electra-small on Chinese tasks.

TextBrewer has been accepted by ACL 2020 as a demo paper, please use our new bib entry.

Mar 17, 2020

Added CoNLL-2003 English NER distillation example, see examples/conll2003_example.

Mar 11, 2020

Updated to 0.1.8 (Improvements on TrainingConfig and train method). See details in releases.

Mar 2, 2020

Initial public version 0.1.7 has been released. See details in releases.

Table of Contents

| Section | Contents |
| :------------ | :------- |
| Introduction | Introduction to TextBrewer |
| Installation | How to install |
| Workflow | Two stages of TextBrewer workflow |
| Quickstart | Example: distilling BERT-base to a 3-layer BERT |
| Experiments | Distillation experiments on typical English and Chinese datasets |
| Core Concepts | Brief explanations of the core concepts in TextBrewer |
| FAQ | Frequently asked questions |
| Known Issues | Known issues |
| Citation | Citation to TextBrewer |

Introduction

TextBrewer is designed for knowledge distillation of NLP models. It provides various distillation methods and offers a distillation framework for quickly setting up experiments.

The main features of TextBrewer are:

Wide-support: it supports various model architectures (especially transformer-based models)

Flexibility: design your own distillation scheme by combining different techniques; it also supports user-defined loss functions, modules, etc.

Easy-to-use: users don't need to modify the model architectures

Built for NLP: it is suitable for a wide variety of NLP tasks: text classification, machine reading comprehension, sequence labeling, ...

TextBrewer currently ships with the following distillation techniques:

Mixed soft-label and hard-label training (see the sketch after this list)

Dynamic loss weight adjustment and temperature adjustment

Various distillation loss functions: hidden states MSE, attention-matrix-based loss, neuron selectivity transfer, ...

Freely adding intermediate features matching losses

Multi-teacher distillation

...
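As a concrete reference for the mixed soft-label and hard-label training listed above, the snippet below is a generic PyTorch sketch of such a loss with a fixed temperature and fixed weights. It is illustrative only and is not TextBrewer's internal implementation; the names `mixed_kd_loss`, `alpha`, `beta` and `T` are assumptions.

```python
import torch
import torch.nn.functional as F

def mixed_kd_loss(logits_S, logits_T, labels, T=8.0, alpha=1.0, beta=1.0):
    """Generic mixed soft-label / hard-label distillation loss (illustrative sketch).

    logits_S, logits_T: (batch, num_classes) student / teacher logits
    labels:             (batch,) gold labels for the hard-label term
    """
    # Soft-label term: KL divergence between temperature-scaled distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(F.log_softmax(logits_S / T, dim=-1),
                    F.softmax(logits_T / T, dim=-1),
                    reduction="batchmean") * (T * T)
    # Hard-label term: ordinary cross-entropy against the gold labels.
    hard = F.cross_entropy(logits_S, labels)
    return alpha * soft + beta * hard
```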

TextBrewer includes:

Distillers: the core components that carry out the distillation. Different distillers perform different distillation modes. There are GeneralDistiller, MultiTeacherDistiller, BasicTrainer, etc.

Configurations and presets: configuration classes for training and distillation, and predefined distillation loss functions and strategies.

Utilities: auxiliary tools such as model parameter analysis.

To start distillation, users need to provide

the models (the trained teacher model and the un-trained student model)

datasets and experiment configurations

TextBrewer has achieved impressive results on several typical NLP tasks. See Experiments.

See Full Documentation for detailed usages.

Installation

Requirements

Python >= 3.6

PyTorch >= 1.1.0

TensorboardX or Tensorboard

NumPy

tqdm

Transformers >= 2.0 (optional, used by some examples)

Install from PyPI

pip install textbrewer

Install from the Github source

git clone https://github.com/airaria/TextBrewer.git

pip install ./textbrewer

Workflow

distillation_workflow_en.png

Stage 1: Preparation:

Train the teacher model

Define and initialize the student model

Construct a dataloader, an optimizer, and a learning rate scheduler (a concrete sketch of this preparation stage follows the workflow overview)

Stage 2: Distillation with TextBrewer:

Construct a TrainingConfig and a DistillationConfig, initialize a distiller

Define an adaptor and a callback. The adaptor is used for adaptation of model inputs and outputs. The callback is called by the distiller during training

Call the train method of the distiller
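For concreteness, the preparation stage might look like the following sketch, which produces exactly the objects assumed by the Quickstart below (teacher_model, student_model, dataloader, optimizer, scheduler_class, scheduler_args). The model checkpoints, the toy random-token dataset, and the scheduler choice are assumptions made only to keep the sketch self-contained.

```python
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

# Teacher: in practice this should be a checkpoint already fine-tuned on the target task;
# the plain pretrained weights are used here only to keep the sketch self-contained.
teacher_model = BertForSequenceClassification.from_pretrained("bert-base-cased")

# Student: a 3-layer BERT with the same hidden size, randomly initialized here
# (see the FAQ for initializing the student from pretrained weights).
student_config = BertConfig.from_pretrained("bert-base-cased", num_hidden_layers=3)
student_model = BertForSequenceClassification(student_config)

# Toy dataloader over random token ids, standing in for a real tokenized dataset.
input_ids = torch.randint(0, student_config.vocab_size, (64, 128))
labels = torch.randint(0, 2, (64,))
dataloader = DataLoader(TensorDataset(input_ids, labels), batch_size=16, shuffle=True)

# Optimizer and learning rate scheduler for the student.
optimizer = AdamW(student_model.parameters(), lr=1e-4)
num_training_steps = len(dataloader) * 30          # e.g. 30 distillation epochs
scheduler_class = get_linear_schedule_with_warmup  # built by the distiller as scheduler_class(optimizer, **scheduler_args)
scheduler_args = {"num_warmup_steps": int(0.1 * num_training_steps),
                  "num_training_steps": num_training_steps}
```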

Quickstart

Here we show the usage of TextBrewer by distilling BERT-base to a 3-layer BERT.

Before distillation, we assume users have provided:

A trained teacher model teacher_model (BERT-base) and a to-be-trained student model student_model (3-layer BERT).

A dataloader of the dataset, an optimizer, and a learning rate scheduler class (or builder function) scheduler_class together with its arguments dict scheduler_args.

Distill with TextBrewer:

```python
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# Show the statistics of model parameters
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)

print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)

# Define an adaptor for translating the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of the model outputs are the logits and hidden states
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}

# Training configuration
train_config = TrainingConfig()

# Distillation configuration
# Matching different layers of the student and the teacher
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])

# Build the distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

# Start!
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                    callback=None)
```

Examples can be found in the examples directory :

examples/random_token_example: a simple runnable toy example that demonstrates the usage of TextBrewer. This example performs distillation on a text classification task with random tokens as inputs.

examples/cmrc2018_example (Chinese): distillation on CMRC 2018, a Chinese MRC task, using DRCD as data augmentation.

examples/mnli_example (English): distillation on MNLI, an English sentence-pair classification task. This example also shows how to perform multi-teacher distillation.

examples/conll2003_example (English): distillation on the CoNLL-2003 English NER task, which takes the form of sequence labeling.

Experiments

We have performed distillation experiments on several typical English and Chinese NLP datasets. The setups and configurations are listed below.

Models

For English tasks, the teacher model is BERT-base-cased.

For Chinese tasks, the teacher models are RoBERTa-wwm-ext and Electra-base released by the Joint Laboratory of HIT and iFLYTEK Research.

We have tested different student models. To compare with public results, the student models are built with standard transformer blocks except for BiGRU, which is a single-layer bidirectional GRU. The architectures are listed below. Note that the number of parameters includes the embedding layer but does not include the output layer of each specific task.

English models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
| :------------------------ | :-----: | :---------: | :---------------: | :-----: | :-----------: |
| BERT-base-cased (teacher) | 12 | 768 | 3072 | 108M | 100% |
| T6 (student) | 6 | 768 | 3072 | 65M | 60% |
| T3 (student) | 3 | 768 | 3072 | 44M | 41% |
| T3-small (student) | 3 | 384 | 1536 | 17M | 16% |
| T4-Tiny (student) | 4 | 312 | 1200 | 14M | 13% |
| T12-nano (student) | 12 | 256 | 1024 | 17M | 16% |
| BiGRU (student) | - | 768 | - | 31M | 29% |

Chinese models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
| :------------------------ | :-----: | :---------: | :---------------: | :-----: | :-----------: |
| RoBERTa-wwm-ext (teacher) | 12 | 768 | 3072 | 102M | 100% |
| Electra-base (teacher) | 12 | 768 | 3072 | 102M | 100% |
| T3 (student) | 3 | 768 | 3072 | 38M | 37% |
| T3-small (student) | 3 | 384 | 1536 | 14M | 14% |
| T4-Tiny (student) | 4 | 312 | 1200 | 11M | 11% |
| Electra-small (student) | 12 | 256 | 1024 | 12M | 12% |

The T4-tiny architecture is the same as TinyBERT [4].

The T3 architecture is the same as BERT3-PKD [2].

Distillation Configurations

```python
distill_config = DistillationConfig(temperature=8, intermediate_matches=matches)
# Other arguments take the default values
```

The matches are different for different models:

| Model | matches |
| :------------ | :------ |
| BiGRU | None |
| T6 | L6_hidden_mse + L6_hidden_smmd |
| T3 | L3_hidden_mse + L3_hidden_smmd |
| T3-small | L3n_hidden_mse + L3_hidden_smmd |
| T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
| T12-nano | small_hidden_mse + small_hidden_smmd |
| Electra-small | small_hidden_mse + small_hidden_smmd |

The definitions of matches are at examples/matches/matches.py.

We use GeneralDistiller in all the distillation experiments.
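When the student's hidden size differs from the teacher's (for example T4-Tiny's 312 vs. BERT-base's 768), a match entry also needs a projection so that the MSE loss can be computed. The sketch below follows the dictionary format shown in the Quickstart; the specific layer indices and the 'proj' field are illustrative assumptions, so check matches.py and the documentation for the exact entries.

```python
from textbrewer import DistillationConfig

# One intermediate match between teacher layer 8 and student layer 3, with a
# learnable linear projection from the student's 312-dim hidden states up to
# the teacher's 768 dimensions so that hidden_mse can be applied.
match_hidden_proj = {
    'layer_T': 8,
    'layer_S': 3,
    'feature': 'hidden',
    'loss': 'hidden_mse',
    'weight': 1,
    'proj': ['linear', 312, 768],   # [projection type, student dim, teacher dim]
}

distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[match_hidden_proj])
```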

Training Configurations

Learning rate is 1e-4 (unless otherwise specified).

We train all the models for 30~60 epochs.

Results on English Datasets

We experiment on the following typical English datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
| :--------- | :------------------ | :------- | :----- | :--- | :--- |
| MNLI | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
| SQuAD | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
| CoNLL-2003 | sequence labeling | F1 | 23K | 6K | named entity recognition |

We list the public results from DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT and our results below for comparison.

Public results:

| Model (public) | MNLI | SQuAD | CoNLL-2003 |
| :------------------- | :---------- | :---------- | :--------- |
| DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
| BERT6-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
| BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
| BERT3-PKD (T3) | 76.7 / 76.3 | - | - |
| TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |

Our results:

| Model (ours) | MNLI | SQuAD | CoNLL-2003 |
| :------------------------ | :---------- | :---------- | :--------- |
| BERT-base-cased (teacher) | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
| BiGRU | - | - | 85.3 |
| T6 | 83.5 / 84.0 | 80.8 / 88.1 | 90.7 |
| T3 | 81.8 / 82.7 | 76.4 / 84.9 | 87.5 |
| T3-small | 81.3 / 81.7 | 72.3 / 81.4 | 78.6 |
| T4-tiny | 82.0 / 82.6 | 75.2 / 84.0 | 89.1 |
| T12-nano | 83.2 / 83.9 | 79.0 / 86.6 | 89.6 |

Note:

The equivalent model structures of the public models are shown in parentheses after their names.

When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD and HotpotQA is used for data augmentation on CoNLL-2003.

When distilling to T12-nano, HotpotQA is used for data augmentation on CoNLL-2003.

Results on Chinese Datasets

We experiment on the following typical Chinese datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
| :-------- | :------------------ | :------ | :----- | :--- | :--- |
| XNLI | text classification | Acc | 393K | 2.5K | Chinese translation version of MNLI |
| LCQMC | text classification | Acc | 239K | 8.8K | sentence-pair matching, binary classification |
| CMRC 2018 | reading comprehension | EM/F1 | 10K | 3.4K | span-extraction machine reading comprehension |
| DRCD | reading comprehension | EM/F1 | 27K | 3.5K | span-extraction machine reading comprehension (Traditional Chinese) |
| MSRA NER | sequence labeling | F1 | 45K | 3.4K (#Test) | Chinese named entity recognition |

The results are listed below.

| Model | XNLI | LCQMC | CMRC 2018 | DRCD |
| :------------------------ | :--- | :---- | :---------- | :---------- |
| RoBERTa-wwm-ext (teacher) | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
| T3 | 78.4 | 89.0 | 66.4 / 84.2 | 78.2 / 86.4 |
| T3-small | 76.0 | 88.1 | 58.0 / 79.3 | 75.8 / 84.8 |
| T4-tiny | 76.2 | 88.4 | 61.8 / 81.8 | 77.3 / 86.1 |

| Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
| :---------------------- | :--- | :---- | :---------- | :---------- | :------- |
| Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
| Electra-small | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |

Note:

Learning rate decay is not used in distillation on CMRC 2018 and DRCD.

CMRC 2018 and DRCD serve as each other's augmentation dataset during distillation.

The settings for training the Electra-base teacher model can be found at Chinese-ELECTRA.

The Electra-small student model is initialized with the pretrained weights.

Core Concepts

Configurations

TrainingConfig: configuration related to general deep learning model training

DistillationConfig: configuration related to distillation methods
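For reference, a typical pair of configurations might look like the sketch below. Only a few common arguments are shown and the particular values are assumptions; see the Full Documentation for the complete argument lists.

```python
from textbrewer import TrainingConfig, DistillationConfig

# General training settings: device, logging and checkpointing.
train_config = TrainingConfig(
    device='cuda',
    log_dir='logs',             # TensorBoard logs written here
    output_dir='saved_models',  # student checkpoints saved here
    ckpt_frequency=1)           # number of checkpoints per epoch

# Distillation settings: temperature, hard-label weight and KD loss type.
distill_config = DistillationConfig(
    temperature=8,
    hard_label_weight=0,        # 0 means pure soft-label distillation
    kd_loss_type='ce')
```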

Distillers

Distillers are in charge of conducting the actual experiments. The following distillers are available:

BasicDistiller: single-teacher single-task distillation, provides basic distillation strategies.

GeneralDistiller (Recommended): single-teacher single-task distillation, supports intermediate feature matching. Recommended for most use cases.

MultiTeacherDistiller: multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student model (see the sketch after this list). This class doesn't support intermediate feature matching.

MultiTaskDistiller: multi-task distillation, which distills multiple teacher models (of different tasks) into a single student. This class doesn't support intermediate feature matching.

BasicTrainer: supervised training of a single model on a labeled dataset; not for distillation. It can be used to train a teacher model.
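As a sketch of how a different distiller is swapped in, a multi-teacher setup mainly changes the construction call: model_T becomes a list of teacher models fine-tuned on the same task. The snippet reuses names from the Quickstart, and the exact signature should be checked against the documentation; treat it as an assumption rather than verbatim API usage.

```python
from textbrewer import MultiTeacherDistiller, TrainingConfig, DistillationConfig

# Several teachers of the same task are distilled into one student.
distiller = MultiTeacherDistiller(
    train_config=TrainingConfig(),
    distill_config=DistillationConfig(temperature=8),
    model_T=[teacher_model_1, teacher_model_2, teacher_model_3],  # list of teachers
    model_S=student_model,
    adaptor_T=simple_adaptor,   # assumed here: one adaptor shared by all teachers
    adaptor_S=simple_adaptor)

with distiller:
    distiller.train(optimizer, dataloader, num_epochs=30, callback=None)
```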

User-Defined Functions

In TextBrewer, there are two functions that should be implemented by users: callback and adaptor.

Callback

At each checkpoint, after saving the student model, the callback function will be called by the distiller. A callback can be used to evaluate the performance of the student model at each checkpoint.
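A minimal callback might look like the sketch below: the distiller calls it with the current student model and the global training step. The evaluate function and eval_dataloader are hypothetical user-side objects, and the exact keyword names should be checked against the documentation.

```python
import torch

def evaluate_callback(model, step):
    # Called by the distiller at each checkpoint, after the student model has been saved.
    model.eval()
    with torch.no_grad():
        # `evaluate` and `eval_dataloader` are user-supplied (hypothetical here):
        # run the dev set through the student and return a metric such as accuracy.
        accuracy = evaluate(model, eval_dataloader)
    print(f"step {step}: dev accuracy = {accuracy:.4f}")
    model.train()
```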

Adaptor

It converts the model inputs and outputs into the specified format so that they can be recognized by the distiller and the distillation losses can be computed. At each training step, the batch and the model outputs are passed to the adaptor; the adaptor re-organizes the data and returns a dictionary.
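Besides the logits and hidden states shown in the Quickstart, an adaptor may also return keys such as the attention matrices, the input mask, and the labels. The sketch below assumes a model whose forward pass returns (loss, logits, hidden_states, attentions) and a batch stored as a dictionary; both are assumptions about the user's model and data, not requirements of TextBrewer.

```python
def full_adaptor(batch, model_outputs):
    # Assumed model output order: (loss, logits, hidden_states, attentions).
    return {
        'logits':      model_outputs[1],          # used for the soft-label (KD) loss
        'hidden':      model_outputs[2],          # used by intermediate hidden-state matches
        'attention':   model_outputs[3],          # used by attention-based matches
        'inputs_mask': batch['attention_mask'],   # masks padded positions in the losses
        'labels':      batch['labels'],           # used for the hard-label loss
    }
```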

For more details, see the explanations in Full Documentation.

FAQ

Q: How to initialize the student model?

A: The student model can be randomly initialized (i.e., with no prior knowledge) or initialized with pre-trained weights. For example, when distilling a BERT-base model to a 3-layer BERT, you could initialize the student with RBT3 (for Chinese tasks) or with the first three layers of BERT (for English tasks) to avoid the cold-start problem. We recommend initializing the student from pre-trained weights whenever possible to take full advantage of large-scale pre-training.
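For example, with the transformers library the first layers of a pretrained checkpoint can be reused simply by shrinking num_hidden_layers, since the names of the remaining layers still match those in the checkpoint. The sketch below covers the English case; the model class depends on your task and is an assumption here.

```python
from transformers import BertConfig, BertForSequenceClassification

# A 3-layer student whose embeddings and layers 0-2 are loaded from BERT-base;
# the weights of the discarded deeper layers are simply ignored.
student_config = BertConfig.from_pretrained("bert-base-cased", num_hidden_layers=3)
student_model = BertForSequenceClassification.from_pretrained("bert-base-cased",
                                                              config=student_config)
```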

Q: How to set training hyperparameters for the distillation experiments?

A: Knowledge distillation usually requires more training epochs and a larger learning rate than training on the labeled dataset. For example, fine-tuning BERT-base on SQuAD usually takes 3 epochs with lr=3e-5, whereas distillation takes 30~50 epochs with lr=1e-4. These conclusions are based on our experiments; you are advised to tune the hyperparameters on your own data.

Known Issues

Compatibility with FP16 training has not been tested.

Multi-GPU training support is only available through DataParallel currently.

Citation

If you find TextBrewer is helpful, please cite our paper:

```bibtex
@InProceedings{textbrewer-acl2020-demo,
  author    = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
  title     = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
  booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
  year      = "2020",
  publisher = "Association for Computational Linguistics"
}
```

Follow Us

Follow our official WeChat account to keep updated with our latest technologies!

hfl_qrcode.jpg
