Knowledge Distillation: A Detailed Explanation

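This article explores how knowledge distillation transfers the knowledge of a large, complex model to a small network, achieving model compression while retaining performance. Knowledge distillation uses the teacher model's soft-target probability distribution to guide the training of the student model. This lets the student reach predictive ability close to the teacher's on limited computational resources, and in some cases even surpass the teacher.
(Summary generated by CSDN using AI.)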


Currently, especially in NLP, very large models are being trained, and many of them cannot even fit on an average person’s hardware. We could instead train a small network that runs within the limited computational resources of a mobile device. But a small model cannot extract many of the complex features that are useful for making predictions, unless you devise some elegant algorithm to do so. Moreover, due to the law of diminishing returns, a large increase in model size often yields only a small increase in accuracy.

Two common ways to address this problem are:

  • Knowledge Distillation.
  • Model Compression.

In this post, I focus on knowledge distillation. Using distillation, one can reduce the size of models like BERT by 87% while retaining 96% of their performance. Recent work even suggests that distilled (student) models can exceed the performance of their teachers.

Shortcoming of normal neural networks

Take the MNIST dataset as an example, and pick a sample image of the digit 3.

In the training data, the digit 3 is represented by the corresponding one-hot vector:
[0 0 0 1 0 0 0 0 0 0]. This vector simply says that the digit in the image is 3. It says nothing explicit about the shape of 3, for example that the shape of 3 is similar to that of 8. Hence, the neural network is never explicitly asked to learn this more general understanding of the training data.
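To make this concrete, here is a small Python sketch (an illustration, not from the original post) that builds one-hot labels for MNIST's ten classes and measures the distance between them. The helper names `one_hot` and `euclidean` are hypothetical. Every pair of distinct one-hot labels is equally far apart, so the label for 3 carries no hint that 3 looks more like 8 than like 1.

```python
import math

NUM_CLASSES = 10  # MNIST digits 0-9

def one_hot(label, num_classes=NUM_CLASSES):
    """Return the one-hot (hard target) vector for a class label."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

three = one_hot(3)
print(three)  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# The hard label for 3 is exactly as far from 8 as it is from 1:
# one-hot vectors encode no notion of similarity between classes.
print(euclidean(three, one_hot(8)), euclidean(three, one_hot(1)))  # 1.4142135623730951 1.4142135623730951
```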

Generalization of Information

The goal of a neural network is to predict the output for samples it has never seen during training by generalizing the knowledge in the training data. Take the example of a discriminative neural network whose objective is to identify the digit in a picture. The network returns a probability distribution over all classes 0, 1, 2, ..., 9. This distribution tells us a lot about how well the network generalizes the concepts in the training data.

For a decently trained neural network on MNIST, given this image of a 3:

  • the probability for 3 is significantly greater than the probabilities for 8 and 0;

  • the probabilities for 8 and 0 are comparable to each other;

  • the probabilities for 8 and 0 are still noticeably higher than those for the other digits.

So the network identifies the digit in the image as a 3. It also suggests that the shape of 3 is quite similar to the shapes of 8 and 0.
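A short Python sketch (an illustration, not from the original post) shows how to read a network's output this way. The logits below are made-up values standing in for a trained MNIST classifier's raw scores on an image of a 3. A standard softmax converts them into a probability distribution, and sorting the classes by probability shows 8 and 0 as the runners-up.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for classes 0..9 on an image of a "3".
logits = [2.0, -1.0, 0.0, 6.0, -2.0, 0.5, -1.5, -0.5, 2.5, 0.0]
probs = softmax(logits)

# Rank the classes from most to least probable.
ranking = sorted(range(10), key=lambda k: probs[k], reverse=True)
print(ranking[:3])  # [3, 8, 0]
for k in ranking[:3]:
    print(k, round(probs[k], 4))
```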

To obtain such informative predictions, we usually train a large and complex network, or an ensemble of models, that can extract important features from the image data and can therefore produce better predictions.

However, such models are usually very cumbersome (hence the term “cumbersome model/network”, meaning deep and complex). Their depth gives them the ability to extract complex features, and their complexity gives them the power to remain accurate. But they are so heavy that running them requires a large amount of memory and a powerful GPU for large, complex computations. That is why we want to transfer the knowledge learned by such a model to a much smaller model that can easily run on a mobile device.

Knowledge Distillation

A few Definitions

  • soft targets: the network’s probability distribution across all classes.

  • hard targets: the one-hot vector labels in the original training data.

  • transfer set: the data we pass through the cumbersome model, using its output (probability distribution) as the target values for training the small model. It can consist of the dataset used to train the original model, a new dataset, or both.
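The following sketch builds a tiny transfer set under stated assumptions. `teacher` is a hypothetical stand-in for a trained cumbersome model (a fixed linear layer plus softmax over three classes), and the toy inputs and labels are invented. Each input is paired with the teacher's soft target, and the original hard target is kept alongside for comparison.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(label, num_classes):
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

# Hypothetical "cumbersome" teacher: fixed weights, 2 input features -> 3 classes.
TEACHER_W = [[1.5, -0.5], [-1.0, 2.0], [0.8, 0.8]]

def teacher(x):
    """Return the teacher's soft target (class probabilities) for input x."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in TEACHER_W]
    return softmax(logits)

# Toy labeled data -> transfer set of (input, soft target, hard target).
data = [([1.0, 0.0], 0), ([0.0, 1.0], 1), ([0.5, 0.5], 2)]
transfer_set = [(x, teacher(x), one_hot(y, 3)) for x, y in data]

for x, soft, hard in transfer_set:
    print(x, [round(p, 3) for p in soft], hard)
```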

General idea of knowledge distillation

Knowledge distillation is a simple way to improve the performance of deep learning models on mobile devices. In this process, we train a large and complex network or an ensemble model which can extract important features from the given data and can, therefore, produce better predictions. Then we train a small network with the help of the cumbersome model. This small network will be able to produce comparable results, and in some cases, it can even be made capable of replicating the results of the cumbersome network.

You can distill a large, complex network into a much smaller network, and the smaller network does a reasonable job of approximating the original function learned by the deep network.
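To make the general idea concrete, here is a minimal sketch of a distillation loop in plain Python. It assumes a hypothetical teacher (for brevity, itself a tiny fixed model rather than a real trained network) and a student made of a single linear layer with softmax. The student is trained with plain SGD to minimize the cross-entropy between the teacher's soft targets and its own predictions. The sketch omits the temperature and the hard-label loss that are often combined with this objective. All names and hyperparameters (`TEACHER_W`, `lr`, 200 epochs) are illustrative choices.

```python
import math
import random

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, pred):
    """H(target, pred) = -sum_k target_k * log(pred_k)."""
    return -sum(t * math.log(p) for t, p in zip(target, pred) if t > 0)

# Hypothetical teacher (stand-in for a trained cumbersome model): 2 features -> 3 classes.
TEACHER_W = [[1.5, -0.5], [-1.0, 2.0], [0.8, 0.8]]

def teacher(x):
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in TEACHER_W])

def student(W, b, x):
    """Student: one linear layer (weights W, bias b) followed by softmax."""
    return softmax([sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)])

# Transfer set: toy inputs labeled with the teacher's soft targets.
random.seed(0)
inputs = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(50)]
targets = [teacher(x) for x in inputs]

W = [[0.0, 0.0] for _ in range(3)]
b = [0.0, 0.0, 0.0]
lr = 0.5  # learning rate (arbitrary choice for this toy problem)

def avg_loss():
    return sum(cross_entropy(t, student(W, b, x)) for x, t in zip(inputs, targets)) / len(inputs)

# Lowest achievable loss: the average entropy of the teacher's soft targets.
loss_floor = sum(cross_entropy(t, t) for t in targets) / len(targets)
initial_loss = avg_loss()

for epoch in range(200):
    for x, t in zip(inputs, targets):
        p = student(W, b, x)
        for k in range(3):
            g = p[k] - t[k]  # d(cross-entropy)/d(logit_k) for a softmax output
            for j in range(2):
                W[k][j] -= lr * g * x[j]
            b[k] -= lr * g

final_loss = avg_loss()
print(round(initial_loss, 4), round(final_loss, 4), round(loss_floor, 4))
```

Because the student only ever sees the teacher's outputs, its loss can approach, but never go below, the average entropy of the teacher's soft targets (`loss_floor`). At that point the student reproduces the teacher's distributions on the transfer set.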

Teacher and Student

The distilled model (the student) is trained to mimic the output of the larger network (the teacher), instead of being trained directly on the raw labeled data.

The key point is that the teacher outputs class probabilities: soft labels rather than hard labels. A digit classifier over the classes 0, 3 and 8 might output “0: 0.1, 3: 0.75, 8: 0.15” instead of “0: 0, 3: 1, 8: 0”. Why bother? Because these soft labels are more informative than the original ones: they tell the student that 3 very slightly resembles 0 and 8.
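A small numerical sketch shows why soft labels are more informative. Two hypothetical students both give 3 a probability of 0.75 but disagree about which other digit it resembles. Against the hard label their cross-entropy losses are identical. Against the teacher's soft label “0: 0.1, 3: 0.75, 8: 0.15”, the student whose ranking matches the teacher's gets the lower loss. The student outputs are invented for illustration.

```python
import math

def cross_entropy(target, pred):
    """H(target, pred) = -sum_k target_k * log(pred_k); terms with target_k = 0 vanish."""
    return -sum(t * math.log(p) for t, p in zip(target, pred) if t > 0)

# Class order: 0, 3, 8
hard_label = [0.0, 1.0, 0.0]
soft_label = [0.1, 0.75, 0.15]  # teacher output from the example above

student_a = [0.05, 0.75, 0.20]  # hypothetical: thinks this 3 looks more like 8
student_b = [0.20, 0.75, 0.05]  # hypothetical: thinks this 3 looks more like 0

# Hard label: only p(3) matters, so both students get the same loss.
print(cross_entropy(hard_label, student_a), cross_entropy(hard_label, student_b))
# Soft label: student_a, which agrees with the teacher, gets the lower loss.
print(round(cross_entropy(soft_label, student_a), 4), round(cross_entropy(soft_label, student_b), 4))
```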

Student models can often come very close to teacher-level performance.

Temperature & Entropy

Temperature and entropy are concepts from physics, where entropy increases with temperature. The same relationship holds for the softened output distribution used in distillation.

When soft targets have high entropy, they provide much more information per training sample than hard targets. For example, the soft targets “0: 0.1, 3: 0.75, 8: 0.15” contain information such as the fact that 0 and 8 both bear some similarity to this 3. The hard targets “0: 0, 3: 1, 8: 0” do not express any such relation among 0, 3 and 8.

However, soft targets are less useful when the output distribution has low entropy (e.g., “0: 0.01, 3: 0.98, 8: 0.01”). In that case, we need to raise the entropy to make the targets more informative.
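The entropies of the three example distributions can be compared directly. This sketch computes Shannon entropy in nats (natural logarithm, an assumption; any base gives the same ordering). It covers the hard target, the low-entropy soft target, and the informative soft target from the text.

```python
import math

def entropy(probs):
    """Shannon entropy in nats; by convention 0 * log(0) = 0, so zero entries are skipped."""
    return sum(-p * math.log(p) for p in probs if p > 0)

# Class order: 0, 3, 8 (values from the examples above)
hard = [0.0, 1.0, 0.0]
peaked = [0.01, 0.98, 0.01]  # soft, but low entropy
soft = [0.1, 0.75, 0.15]

for name, dist in [("hard", hard), ("peaked", peaked), ("soft", soft)]:
    print(name, round(entropy(dist), 4))
```

The hard target has zero entropy. The peaked soft target is only slightly above zero, and the informative soft target is much higher, though still below the three-class maximum of log 3.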

Specifically, we use a temperature parameter T to adjust the entropy of the output distribution, using the formula

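$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \qquad (1) $$

where z_i is the logit (raw score) for class i and q_i is the resulting probability. Setting T = 1 gives the ordinary softmax, and a higher T produces a softer, higher-entropy distribution over the classes.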
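The temperature-scaled softmax can be sketched in a few lines of Python. The logits `z` are hypothetical values for the classes 0, 3 and 8. The loop shows that raising T increases the entropy of q while leaving the ranking of the classes unchanged.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability; q is unchanged
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return sum(-p * math.log(p) for p in probs if p > 0)

# Hypothetical logits for classes 0, 3, 8 on an image of a "3".
z = [2.0, 7.0, 2.5]
for T in (1.0, 2.0, 5.0, 20.0):
    q = softmax_with_temperature(z, T)
    print(T, [round(p, 3) for p in q], round(entropy(q), 3))
```

At T = 1 this is the ordinary softmax. As T grows, the distribution flattens toward uniform, whose entropy for three classes is log 3 (about 1.0986 nats). Because dividing every logit by the same positive T preserves their order, the class ranking stays the same.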
