Distilling the Knowledge in a Neural Network - paper notes (knowledge distillation)

Copyright notice: this is an original post by the blogger, released under the CC 4.0 BY-SA license. Please include the original link and this notice when reposting.
Original link: https://blog.csdn.net/bryant_meng/article/details/79260165

arXiv-2015
In NIPS Deep Learning Workshop, 2014



This post only covers the classification-related parts of "Distilling the Knowledge in a Neural Network"; for more related papers see 《Paper》.

1 Background and Motivation

A very simple way to improve model performance is to
train many different models on the same data and then to average their predictions

Drawbacks:

  • Predicting with an ensemble is too cumbersome
  • It may be too computationally expensive to deploy to a large number of users, especially if the individual models are large neural networks

Caruana et al. showed that transferring an ensemble into a single model is feasible
(they demonstrate convincingly that the knowledge acquired by a large ensemble of models can be transferred to a single small model)

The authors use knowledge distillation (a new kind of model compression) to carry out this ensemble-to-single-model transfer.


2 Conceptual block

[slide: advantages of soft targets over hard targets]
1) There is a common misconception about the knowledge a model has learned: it is often equated with the trained parameter values. This narrow view has held back knowledge transfer, because as soon as the network architecture changes, those so-called knowledge/parameters can no longer be reused. The authors propose a more general, abstract view: knowledge is the mapping the network has learned from input vectors to output vectors.

Seen this way, the knowledge is no longer tied to a particular architecture, which is what makes it possible for a small network to learn from a large one!

2) Another misconception is that the training objective should match the true objective as closely as possible. In practice, training optimizes performance on the training set, whereas what we really care about is how well the model generalizes to new data. It would obviously be ideal to train directly for good generalization, but that is nearly impossible because the necessary information about generalization is not available. With knowledge distillation, however, the generalization ability learned by the large model can be transferred quite naturally to the small model; since the large, cumbersome model generalizes well, a small model trained under its guidance usually generalizes much better than the same small model trained from scratch.

So how is the large model's generalization ability passed to the small model? Through soft targets: the large network's softmax outputs (a standard softmax with a temperature) are used as labels, and the small network's softmax output is trained to match them. The hard targets are simply the original dataset labels. The advantages of soft targets over hard targets are summarized in the slide above.

Why do soft targets carry information about generalization? My understanding is that, compared with hard targets, soft targets encode much more about the relationships between classes; for example, for an MNIST image of a 2, the teacher may assign noticeably more probability to 3 and 7 than to the other digits, information that a one-hot label throws away.
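A minimal numpy sketch of this point (the logits below are made up for illustration, and softmax_T is a helper defined here, not taken from the paper's code):

import numpy as np

def softmax_T(z, T=1.0):
    # temperature-scaled softmax: q_i = exp(z_i / T) / sum_j exp(z_j / T)
    z = np.asarray(z, dtype=np.float64) / T
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

teacher_logits = np.array([9.0, 5.0, 0.5])  # hypothetical teacher logits for classes [2, 3, 7]
hard_target    = np.array([1.0, 0.0, 0.0])  # one-hot label: the image is a "2"

print(softmax_T(teacher_logits, T=1.0))  # ~[0.982, 0.018, 0.000], nearly as uninformative as the hard target
print(softmax_T(teacher_logits, T=5.0))  # ~[0.613, 0.275, 0.112], the similarity to 3 (and 7) becomes visible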

3 Knowledge Distilling


3.1 hard target

Let us first look at how the hard target (standard softmax) is computed.
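Written out, the standard softmax over the logits $z_i$ (this is Eq. (1) of the paper with $T = 1$):

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$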
For a more concrete, step-by-step worked example of the softmax and its derivative, see reference [2] (a Zhihu post). [figures: worked softmax example]

3.2 soft target

Now look at the effect of the soft target (softmax with a temperature T).
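The soft target is produced by the same softmax but with a temperature $T$ (Eq. (1) of the paper); $T = 1$ recovers the standard softmax, and larger $T$ yields a softer distribution over the classes:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$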
[figure: soft target outputs $q_i$ for three classes as the temperature $T$ increases; the code in Appendix A reproduces this plot]
The x-axis is the temperature $T$ and the y-axis is the soft target output $q_i$. In the high-temperature limit, distillation reduces to matching logits, i.e., minimizing $\frac{1}{2}(z_i - v_i)^2$, whose gradient is $\frac{\partial}{\partial z_i}\bigl[\tfrac{1}{2}(z_i - v_i)^2\bigr] = z_i - v_i$ (see 3.4).

3.4 Advantages of softmax + T over raw logits

Since matching logits is just a special case of matching softmax + T, what advantages does learning from softmax + T have by comparison?
The authors summarize the trade-off as follows (a derivation of the logits special case follows the list):

  • logits are almost completely unconstrained by the cost function used for training the cumbersome model so they could be very noisy
  • very negative logits may convey useful information about the knowledge acquired by the cumbersome model
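For reference, the derivation from Section 2.1 of the paper that makes logit matching the high-temperature limit: with student logits $z_i$, teacher logits $v_i$, teacher soft targets $p_i$, and cross-entropy $C$,

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\,(q_i - p_i) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right) \approx \frac{1}{N T^2}\,(z_i - v_i)$$

when $T$ is large compared with the logits and the logits are zero-mean. This is exactly the gradient of $\frac{1}{2}(z_i - v_i)^2$, i.e., of regressing the teacher's logits directly.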

3.5 Cost function

The small network's loss function is shown below (a cross-entropy against the teacher's soft targets).
[figure: student loss function]
When learning generalization from the large network, train with a relatively large T (the larger T is, the less confident the targets; if the student can still separate the classes under these less confident targets, it will do even better at test time when T = 1, much like training with extra weights on). When learning from the real labels, use T = 1.

Combining the true labels with the soft targets, i.e., taking a weighted sum of the two objectives as the training target, gives better results. The objective then becomes the expression below, where a value of λ smaller than 1 works better.
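A form consistent with the paper (the exact weighting used in the figure above may differ) is a weighted sum of the cross-entropy against the teacher's soft targets, computed at temperature $T$ and scaled by $T^2$ so its gradients stay comparable, and the cross-entropy against the true labels at $T = 1$:

$$\mathcal{L} = \alpha\, T^{2}\, H\!\left(p^{\text{teacher}}_{T},\, q^{\text{student}}_{T}\right) + \beta\, H\!\left(y,\, q^{\text{student}}_{T=1}\right)$$

with the paper recommending a considerably lower weight on the true-label term.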

4 Dataset

MNIST

5 Experiments

Network architectures:

  • Large network: 2 hidden layers, 1200 units per layer, 55,000 training samples, trained with dropout.
  • Small network (baseline): 2 hidden layers, 800 units per layer, no regularization, trained directly in the usual way.
  • Small network (distilled): 2 hidden layers, 800 units per layer, no regularization, trained by knowledge distillation from the large network with T = 20.

Number of test errors:

  • Large network: 67
  • Small network (baseline): 146
  • Small network (distilled): 74

Generalization experiment
To probe the small network's generalization, the authors removed every image of the digit 3 from the transfer set (the data used to train the small network; it can be smaller than the data used to train the large network, and it can even lack ground-truth labels). The small network therefore never sees a 3 during distillation. Nevertheless, at test time it gets 98.6% of the threes right. Moreover, even when the transfer set contains only images of 7s and 8s, the distilled model's test error is only 13.2%. (Both figures are obtained after adjusting the distilled model's class biases, which is what Q1 below asks about.) In other words, the small network inherits generalization ability from the large network!

Q1: How should the bias adjustment in the experiments of Section 3 of the paper be understood?

6 References

【1】[Paper walkthrough] Hinton - Distilling the Knowledge in a Neural Network

【2】A step-by-step worked example of the softmax function and its derivative

【3】Knowledge distillation paper reading (1): Distilling the Knowledge in a Neural Network (with a code reproduction)

7 Appendix

A. How the softmax output changes with temperature

[figure: softmax outputs for logits 0.9 / 0.07 / 0.03 as the temperature T increases, produced by the code below]

import math
import numpy as np
import matplotlib.pyplot as plt

# softmax with temperature for three fixed logits (0.9, 0.07, 0.03),
# evaluated at temperatures T = 1, 2, ..., 19
T = np.arange(1, 20, 1)
y1 = math.e**(0.9/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
y2 = math.e**(0.07/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
y3 = math.e**(0.03/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
plt.plot(T, y1)
plt.plot(T, y2)
plt.plot(T, y3)
plt.legend(["0.9", "0.07", "0.03"])  # legend: the three logits
plt.grid()                           # grid
# plt.savefig('1.png')
plt.show()

 
 

B. Knowledge distilling code (MNIST)

MNIST has 60,000 training images, so why 55,000 here? The input_data loader holds out the remaining 5,000 as validation data.
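A quick sanity check of this split (using the same input_data loader and data path as the code below):

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)
print(mnist.train.num_examples)       # 55000
print(mnist.validation.num_examples)  # 5000
print(mnist.test.num_examples)        # 10000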

B.1 teacher network

2 hidden layers with 1200 units each, 55,000 training samples, trained with dropout = 0.5.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline

random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)

# load the dataset
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)


Definitions for the network layers

# hyper parameters
n_epochs = 50
batch_size = 50
num_nodes_h1 = 1200
num_nodes_h2 = 1200
learning_rate = 0.001

# number of batches
n_batches = len(mnist.train.images) // batch_size # 55000

# weight variable W
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# bias variable b
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# softmax with temperature T
def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp), axis=axis, keep_dims=True)
    return _softmax


The network architecture

# data
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)# drop out

# input to hidden layer 1
W_h1 = weight_variable([784, num_nodes_h1])# 784,1200
b_h1 = bias_variable([num_nodes_h1])# 1200
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1) # relu(wx+b)
h1_drop = tf.nn.dropout(h1, keep_prob) # drop out

# hidden layer 1 to hidden layer 2
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])# 1200,1200
b_h2 = bias_variable([num_nodes_h2])# 1200
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)# relu(wx+b)
h2_drop = tf.nn.dropout(h2, keep_prob) # drop out

# hidden layer 2 to output layer
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output

y = tf.nn.softmax(logits) # hard target
y_soft_target = softmax_with_temperature(logits, temp=2.0) # soft target
loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


Train with mini-batches, save the model checkpoints, and record the training loss and the training/test accuracy.

saver = tf.train.Saver()
losses = []
accs = []
test_accs = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):  # epochs
        x_shuffle, y_shuffle = shuffle(mnist.train.images, mnist.train.labels)
        for i in range(n_batches):  # mini-batches
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y = x_shuffle[start:end], y_shuffle[start:end]
            sess.run(train_step, feed_dict={
                x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_loss = sess.run(loss, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_accuracy = sess.run(accuracy, feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
        print("Epoch : %i, train loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_teacher/",
                   global_step=epoch+1)  # only the most recent checkpoints are kept
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")


output

Epoch : 1, train loss : 0.737658, Accuracy: 0.880000, Test accuracy: 0.870400
Epoch : 2, train loss : 0.761208, Accuracy: 0.900000, Test accuracy: 0.877700
Epoch : 3, train loss : 0.589437, Accuracy: 0.920000, Test accuracy: 0.890600
Epoch : 4, train loss : 0.643363, Accuracy: 0.900000, Test accuracy: 0.899900
Epoch : 5, train loss : 0.616038, Accuracy: 0.900000, Test accuracy: 0.900900
Epoch : 6, train loss : 0.611822, Accuracy: 0.860000, Test accuracy: 0.907100
Epoch : 7, train loss : 0.644078, Accuracy: 0.860000, Test accuracy: 0.909100
Epoch : 8, train loss : 0.402896, Accuracy: 0.960000, Test accuracy: 0.911100
Epoch : 9, train loss : 0.572901, Accuracy: 0.960000, Test accuracy: 0.907900
Epoch : 10, train loss : 0.517088, Accuracy: 0.900000, Test accuracy: 0.914600
Epoch : 11, train loss : 0.410240, Accuracy: 0.960000, Test accuracy: 0.914300
Epoch : 12, train loss : 0.945823, Accuracy: 0.800000, Test accuracy: 0.916200
Epoch : 13, train loss : 0.579927, Accuracy: 0.900000, Test accuracy: 0.917000
Epoch : 14, train loss : 0.503660, Accuracy: 0.860000, Test accuracy: 0.918300
Epoch : 15, train loss : 0.532867, Accuracy: 0.940000, Test accuracy: 0.918600
Epoch : 16, train loss : 0.430909, Accuracy: 0.940000, Test accuracy: 0.920300
Epoch : 17, train loss : 0.507866, Accuracy: 0.920000, Test accuracy: 0.920600
Epoch : 18, train loss : 0.453426, Accuracy: 0.920000, Test accuracy: 0.925200
Epoch : 19, train loss : 0.689311, Accuracy: 0.920000, Test accuracy: 0.926600
Epoch : 20, train loss : 0.379545, Accuracy: 0.940000, Test accuracy: 0.926100
Epoch : 21, train loss : 0.431786, Accuracy: 0.920000, Test accuracy: 0.926800
Epoch : 22, train loss : 0.401257, Accuracy: 0.960000, Test accuracy: 0.927300
Epoch : 23, train loss : 0.587902, Accuracy: 0.960000, Test accuracy: 0.928600
Epoch : 24, train loss : 0.620417, Accuracy: 0.880000, Test accuracy: 0.927400
Epoch : 25, train loss : 0.365211, Accuracy: 0.940000, Test accuracy: 0.929500
Epoch : 26, train loss : 0.427130, Accuracy: 0.960000, Test accuracy: 0.930300
Epoch : 27, train loss : 0.253452, Accuracy: 0.900000, Test accuracy: 0.930800
Epoch : 28, train loss : 0.427312, Accuracy: 0.920000, Test accuracy: 0.930900
Epoch : 29, train loss : 0.419188, Accuracy: 0.900000, Test accuracy: 0.933100
Epoch : 30, train loss : 0.268312, Accuracy: 0.940000, Test accuracy: 0.933800
Epoch : 31, train loss : 0.346375, Accuracy: 0.920000, Test accuracy: 0.933500
Epoch : 32, train loss : 0.292108, Accuracy: 0.960000, Test accuracy: 0.933000
Epoch : 33, train loss : 0.436444, Accuracy: 0.960000, Test accuracy: 0.935100
Epoch : 34, train loss : 0.278850, Accuracy: 0.940000, Test accuracy: 0.934900
Epoch : 35, train loss : 0.277737, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 36, train loss : 0.425431, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 37, train loss : 0.359413, Accuracy: 0.940000, Test accuracy: 0.937800
Epoch : 38, train loss : 0.338502, Accuracy: 0.960000, Test accuracy: 0.937600
Epoch : 39, train loss : 0.433313, Accuracy: 0.880000, Test accuracy: 0.937100
Epoch : 40, train loss : 0.529199, Accuracy: 0.860000, Test accuracy: 0.938700
Epoch : 41, train loss : 0.657401, Accuracy: 0.920000, Test accuracy: 0.938500
Epoch : 42, train loss : 0.491150, Accuracy: 0.920000, Test accuracy: 0.938600
Epoch : 43, train loss : 0.334091, Accuracy: 0.940000, Test accuracy: 0.940200
Epoch : 44, train loss : 0.298908, Accuracy: 0.940000, Test accuracy: 0.941000
Epoch : 45, train loss : 0.303939, Accuracy: 0.940000, Test accuracy: 0.939800
Epoch : 46, train loss : 0.378838, Accuracy: 0.940000, Test accuracy: 0.939500
Epoch : 47, train loss : 0.323622, Accuracy: 0.920000, Test accuracy: 0.941700
Epoch : 48, train loss : 0.280403, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 49, train loss : 0.390651, Accuracy: 0.920000, Test accuracy: 0.942800
Epoch : 50, train loss : 0.614632, Accuracy: 0.900000, Test accuracy: 0.941700
... completed!

 
 

Visualize the training loss

# plot how the training loss changes
plt.title("Loss of teacher")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()

 
 

[figure: teacher training loss]
Visualize the training and test accuracy

# plot how the training and test accuracy change
plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()

 
 

[figure: teacher training and test accuracy]

Save the training loss and the training/test accuracy to disk.

# save the training loss/accuracy and the test accuracy
np.save("loss_teacher.npy", np.array(losses))
np.save("acc_train_teacher.npy", np.array(accs))
np.save("acc_test_teacher.npy", np.array(test_accs))

 
 

Save the teacher network's soft targets. We pick the checkpoint of an epoch that performed well; below, the 48th epoch is used.

# compute the soft targets using the epoch-48 checkpoint
_soft_targets = []
with tf.Session() as sess:
    saver.restore(sess, "./model_teacher/-48")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob:1.0}))
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_x = mnist.train.images[start:end]
        soft_target = sess.run(y_soft_target, feed_dict={x: batch_x, keep_prob:1.0})
        _soft_targets.append(soft_target)

 
 

Check the shape of _soft_targets and reshape it.

np.shape(_soft_targets)  # (1100, 50, 10) = (n_batches, batch_size, classes)
soft_targets = np.c_[_soft_targets].reshape(55000, 10)  # reshape to (55000, 10)

 
 

Compare the soft targets with the hard targets.

print(soft_targets[:2])
print(mnist.train.labels[:2])  # hard labels, to compare with the soft targets printed above

 
 

output

[[5.2621812e-03 6.1693429e-03 1.5207376e-01 6.1155759e-02 1.4845385e-02
  4.8464271e-03 3.6828788e-03 6.0641229e-01 2.9818511e-02 1.1573344e-01]
 [2.4089564e-03 2.6752956e-03 1.8253580e-02 8.5861373e-01 3.0618338e-04
  1.7423177e-02 9.3506598e-05 3.6187540e-03 8.3464541e-02 1.3142269e-02]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

 
 

Save the teacher network's soft targets to a file so the student network can learn from them.

np.save('soft-targets.npy', soft_targets)

 
 

Check its shape

np.load(file="soft-targets.npy").shape

 
 

output

(55000, 10)

 
 

B.2 student network

The differences from the teacher network are the hidden-layer size (1200 → 600; the paper uses 800) and the loss function; everything else is the same.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

%matplotlib inline

random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)

mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)


Load the teacher network's soft targets.

soft_targets = np.load(file="soft-targets.npy")
print(np.shape(soft_targets))

 
 

output

(55000, 10)

 
 

Hyper-parameter settings and the definitions of W, b, and the temperature softmax

n_epochs = 50
batch_size = 50
num_nodes_h1 = 600 # Before 800
num_nodes_h2 = 600 # Before 800
learning_rate = 0.001

n_batches = len(mnist.train.images) // batch_size

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp),
                                                        axis=axis, keep_dims=True)
    return _softmax


The network architecture

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
soft_target_ = tf.placeholder(tf.float32, [None, 10])

keep_prob = tf.placeholder(tf.float32)
T = tf.placeholder(tf.float32)

W_h1 = weight_variable([784, num_nodes_h1])
b_h1 = bias_variable([num_nodes_h1])
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1)
h1_drop = tf.nn.dropout(h1, keep_prob)

W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])
b_h2 = bias_variable([num_nodes_h2])
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)
h2_drop = tf.nn.dropout(h2, keep_prob)  # dropout is still used here

W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output

y = tf.nn.softmax(logits)
y_soft_target = softmax_with_temperature(logits, temp=T)

loss_hard_target = -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
loss_soft_target = -tf.reduce_sum(soft_target_ * tf.log(y_soft_target),
                                  reduction_indices=[1])

# the paper scales the soft-target term by T^2 so its gradients match the hard-target term;
# here both terms are multiplied by T^2, which only rescales the overall loss
loss = tf.reduce_mean(tf.square(T) * loss_hard_target + tf.square(T) * loss_soft_target)

train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


Start training: train at a high temperature, test at T = 1.

saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        x_shuffle, y_shuffle, soft_targets_shuffle \
                = shuffle(mnist.train.images, mnist.train.labels, soft_targets)
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y, batch_soft_targets \
                    = x_shuffle[start:end], y_shuffle[start:end], soft_targets_shuffle[start:end]
            sess.run(train_step, feed_dict={
                                            x: batch_x, y_: batch_y, soft_target_:batch_soft_targets, 
                                            keep_prob:0.5, T:2.0})
        train_loss = sess.run(loss, feed_dict={
                                            x: batch_x, y_: batch_y, soft_target_:batch_soft_targets, 
                                            keep_prob:0.5, T:2.0})  # high-temperature training
        train_accuracy = sess.run(accuracy, feed_dict={
                                            x: batch_x, y_: batch_y, keep_prob:1.0, T:1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
                                            x: mnist.test.images, y_: mnist.test.labels, keep_prob:1.0, T:1.0})
        # test at low temperature (T = 1)
        print("Epoch : %i, Loss : %f, Accuracy: %f, Test accuracy: %f" % (
                epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_student/", 
                   global_step=epoch+1)
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")

 
 

Output. As you can see, the student ends up surpassing its teacher (compare the test accuracy with the teacher's above).

Epoch : 1, Loss : 7.137307, Accuracy: 0.860000, Test accuracy: 0.868200
Epoch : 2, Loss : 5.926404, Accuracy: 0.940000, Test accuracy: 0.892200
Epoch : 3, Loss : 5.597841, Accuracy: 0.920000, Test accuracy: 0.901400
Epoch : 4, Loss : 5.938632, Accuracy: 0.920000, Test accuracy: 0.913000
Epoch : 5, Loss : 5.872798, Accuracy: 0.920000, Test accuracy: 0.915800
Epoch : 6, Loss : 5.436497, Accuracy: 0.920000, Test accuracy: 0.919300
Epoch : 7, Loss : 5.455486, Accuracy: 0.880000, Test accuracy: 0.924100
Epoch : 8, Loss : 4.402141, Accuracy: 0.980000, Test accuracy: 0.927100
Epoch : 9, Loss : 5.413333, Accuracy: 0.960000, Test accuracy: 0.929700
Epoch : 10, Loss : 4.503023, Accuracy: 0.960000, Test accuracy: 0.931900
Epoch : 11, Loss : 4.971416, Accuracy: 0.960000, Test accuracy: 0.934800
Epoch : 12, Loss : 6.448879, Accuracy: 0.880000, Test accuracy: 0.937300
Epoch : 13, Loss : 6.164934, Accuracy: 0.920000, Test accuracy: 0.939000
Epoch : 14, Loss : 5.904130, Accuracy: 0.880000, Test accuracy: 0.940200
Epoch : 15, Loss : 5.206109, Accuracy: 0.940000, Test accuracy: 0.941200
Epoch : 16, Loss : 4.704682, Accuracy: 0.960000, Test accuracy: 0.942000
Epoch : 17, Loss : 4.707399, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 18, Loss : 4.608377, Accuracy: 0.940000, Test accuracy: 0.944000
Epoch : 19, Loss : 6.394137, Accuracy: 0.900000, Test accuracy: 0.944600
Epoch : 20, Loss : 4.419221, Accuracy: 0.980000, Test accuracy: 0.944900
Epoch : 21, Loss : 4.322970, Accuracy: 0.960000, Test accuracy: 0.946800
Epoch : 22, Loss : 3.958002, Accuracy: 0.960000, Test accuracy: 0.946400
Epoch : 23, Loss : 4.949951, Accuracy: 0.960000, Test accuracy: 0.947600
Epoch : 24, Loss : 5.640293, Accuracy: 0.900000, Test accuracy: 0.947100
Epoch : 25, Loss : 4.615621, Accuracy: 0.940000, Test accuracy: 0.948300
Epoch : 26, Loss : 4.853579, Accuracy: 0.940000, Test accuracy: 0.948600
Epoch : 27, Loss : 4.839081, Accuracy: 0.960000, Test accuracy: 0.949700
Epoch : 28, Loss : 4.525964, Accuracy: 0.940000, Test accuracy: 0.950600
Epoch : 29, Loss : 5.636992, Accuracy: 0.940000, Test accuracy: 0.950700
Epoch : 30, Loss : 4.566214, Accuracy: 0.980000, Test accuracy: 0.951200
Epoch : 31, Loss : 4.846083, Accuracy: 0.960000, Test accuracy: 0.951300
Epoch : 32, Loss : 4.274162, Accuracy: 0.980000, Test accuracy: 0.951700
Epoch : 33, Loss : 4.423202, Accuracy: 0.960000, Test accuracy: 0.951800
Epoch : 34, Loss : 4.516046, Accuracy: 0.940000, Test accuracy: 0.952200
Epoch : 35, Loss : 3.987510, Accuracy: 0.940000, Test accuracy: 0.952900
Epoch : 36, Loss : 4.587525, Accuracy: 0.940000, Test accuracy: 0.953200
Epoch : 37, Loss : 4.149089, Accuracy: 0.960000, Test accuracy: 0.953300
Epoch : 38, Loss : 4.955534, Accuracy: 0.940000, Test accuracy: 0.953900
Epoch : 39, Loss : 5.080862, Accuracy: 0.960000, Test accuracy: 0.954700
Epoch : 40, Loss : 5.033619, Accuracy: 0.900000, Test accuracy: 0.954500
Epoch : 41, Loss : 5.110637, Accuracy: 0.940000, Test accuracy: 0.954100
Epoch : 42, Loss : 5.486012, Accuracy: 0.940000, Test accuracy: 0.954300
Epoch : 43, Loss : 4.117889, Accuracy: 0.980000, Test accuracy: 0.955800
Epoch : 44, Loss : 3.833005, Accuracy: 0.940000, Test accuracy: 0.955900
Epoch : 45, Loss : 4.636988, Accuracy: 0.960000, Test accuracy: 0.954500
Epoch : 46, Loss : 5.074997, Accuracy: 0.940000, Test accuracy: 0.955700
Epoch : 47, Loss : 4.291631, Accuracy: 0.960000, Test accuracy: 0.954800
Epoch : 48, Loss : 4.045475, Accuracy: 0.960000, Test accuracy: 0.956500
Epoch : 49, Loss : 4.960283, Accuracy: 0.920000, Test accuracy: 0.957400
Epoch : 50, Loss : 5.411842, Accuracy: 0.940000, Test accuracy: 0.956300
... completed!

 
 

Visualize the training loss

plt.title("Loss of student")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()

 
 

[figure: student training loss]

Visualize the training and test accuracy

plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()

 
 

[figure: student training and test accuracy]
Check the test accuracy of one saved checkpoint.

with tf.Session() as sess:
    saver.restore(sess, "./model_student/-49")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob:1.0}))

 
 

output

INFO:tensorflow:Restoring parameters from ./model_student/-49
0.9574

 
 

Save the accuracy and the loss.

np.save("loss_student.npy", np.array(losses))
np.save("acc_student.npy", np.array(accs))
np.save("acc_test_student.npy", np.array(test_accs))

 
 