Convolution in Deep Learning: What Happens If the Convolution Kernel Is Initialized to Zero

Preface

  I was recently asked the following question in an interview: what happens if a convolution layer's weights are initialized to zero?

  To answer this, let's first pin down how weights and biases are typically initialized in a neural network. In TensorFlow, weights are generally initialized by the user, who can choose among many initializers, such as glorot_normal_initializer(). Biases, however, default to zero: see the tf.layers.conv2d and tf.layers.dense functions, both of which initialize the bias to zero by default. We will follow this convention: the convolution kernel is initialized to zero, and the bias is also initialized to zero.
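  This starting state is easy to confirm in code. Below is a minimal sketch, assuming TensorFlow 1.x (the tf.layers API used throughout this post); the layer name zero_conv is made up for illustration. With kernel_initializer=tf.zeros_initializer() and the bias left at its documented default, every trainable parameter of the layer starts at exactly zero.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                 kernel_initializer=tf.zeros_initializer(),
                 name='zero_conv')  # bias_initializer defaults to tf.zeros_initializer()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for v in tf.trainable_variables():
        print(v.name, np.abs(sess.run(v)).max())  # both lines print 0.0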

Analysis

  This initialization is quite peculiar: once all of a layer's parameters are zero, that layer's output is necessarily zero. In the next layer, no matter how its kernel is distributed, convolving a zero input still yields zero, and since the bias defaults to zero the output remains zero. Consequently, the output of the entire network is zero.
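  A tiny numpy sketch of this forward-pass argument, using a 1-D convolution as a stand-in for the 2-D one:

import numpy as np

x = np.random.randn(8)                     # arbitrary input
w_zero = np.zeros(3)                       # the zero-initialized kernel
w_next = np.random.randn(3)                # any later kernel

h = np.convolve(x, w_zero, mode='same')    # zero kernel, zero bias
print(h)                                   # all zeros, whatever x was

out = np.convolve(h, w_next, mode='same')  # the next layer's output
print(out)                                 # still all zeros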

  The network's parameters must be updated during backpropagation, and computing both the parameter gradients and the backward-propagated error requires each layer's input, which here is zero for most layers. As a result, no gradient update can take place.
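  In symbols: writing one output of a convolution (or dense) layer as y = \sum_i w_i x_i + b, the chain rule gives

\frac{\partial L}{\partial w_i} = \frac{\partial L}{\partial y}\, x_i,
\qquad
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y}\, w_i .

  A zero input x_i therefore kills the weight gradient, and a zero kernel w_i kills the error signal passed back to earlier layers; together, the two effects freeze the whole network.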

Gradient Updates of the Parameters

  Consider the kernel and bias parameters of some layer.

  If the layer sits before the zero-kernel layer, then from the viewpoint of automatic differentiation, a small perturbation of any one of its parameters is wiped out by the subsequent zero-kernel convolution. The perturbation produces no change in the final loss, which means the parameter's gradient is zero and it cannot be updated.

  If the layer sits after the zero-kernel layer, then again from the automatic-differentiation viewpoint, a small perturbation of a kernel parameter has no effect on the final loss, because the layer's input is zero; its gradient is therefore zero and the kernel cannot be updated. The bias initially contributes nothing to the loss either, and since backpropagation through this layer multiplies by its input (which is zero), the error matrices propagated backward are all zero, so no gradient update is possible. Both cases are verified in the sketch below.
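  The two cases can be checked directly with tf.gradients. This sketch again assumes TensorFlow 1.x; the layer names before/zero/after and the toy sum-loss are made up for illustration. It stacks three conv layers with the middle one zero-initialized and prints the gradient magnitudes of the two outer kernels:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8, 8, 1])
h1 = tf.layers.conv2d(x, 4, 3, 1, 'SAME', activation=tf.nn.relu,
                      kernel_initializer=tf.glorot_normal_initializer(),
                      name='before')     # sits before the zero layer
h2 = tf.layers.conv2d(h1, 4, 3, 1, 'SAME', activation=tf.nn.relu,
                      kernel_initializer=tf.zeros_initializer(),
                      name='zero')       # output is identically zero
h3 = tf.layers.conv2d(h2, 4, 3, 1, 'SAME', activation=tf.nn.relu,
                      kernel_initializer=tf.glorot_normal_initializer(),
                      name='after')      # input is identically zero
loss = tf.reduce_sum(h3)                 # toy loss with nonzero error at h3

w_before = [v for v in tf.trainable_variables() if v.name == 'before/kernel:0'][0]
w_after = [v for v in tf.trainable_variables() if v.name == 'after/kernel:0'][0]
grads = tf.gradients(loss, [w_before, w_after])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    vals = sess.run(grads, {x: np.random.randn(2, 8, 8, 1)})
    print([np.abs(v).max() for v in vals])  # [0.0, 0.0]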

Code

  The code below trains on a fixed batch of samples so that different initialization schemes can be compared under controlled conditions. In fact, whatever input the model receives, the outcome is the same.

# coding=utf-8
# python3

import tensorflow as tf
from tensorflow.keras.datasets import mnist

import numpy as np


class DataPipeline():
    """Thin wrapper around MNIST providing random and fixed batches."""

    def __init__(self):
        (image_train, label_train), (image_test, label_test) = mnist.load_data()
        self.image_train = np.array(image_train)
        self.image_test = np.array(image_test)

        self.label_train = np.array(label_train)
        self.label_test = np.array(label_test)

    def next(self, n=1, tag="train"):
        if tag == 'train':
            length = len(self.image_train)
            index = np.random.choice(np.arange(length), n)
            images = self.image_train[index]
            labels = self.label_train[index]
            return np.reshape(images, [n, -1]), labels
        if tag == 'test':
            length = len(self.image_test)
            index = np.random.choice(np.arange(length), n)
            images = self.image_test[index]
            labels = self.label_test[index]
            return np.reshape(images, [n, -1]), labels

    def fixed(self, n=50, tag='train'):
        if tag == 'train':
            images = self.image_train[:n]
            labels = self.label_train[:n]

            return np.reshape(images, [n, -1]), labels
        if tag == 'test':
            images = self.image_test[:n]
            labels = self.label_test[:n]

            return np.reshape(images, [n, -1]), labels


def conv(x):
    tf.set_random_seed(1)

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv1',
                         reuse=tf.AUTO_REUSE)

    x = tf.layers.max_pooling2d(x, 2, 2, 'VALID')

    ##############################
    # This is the baseline, where every variable is initialized with its own random values
    ##############################

    # x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
    #                      activation=tf.nn.relu,
    #                      use_bias=True,
    #                      kernel_initializer=tf.glorot_normal_initializer(),
    #                      bias_initializer=tf.zeros_initializer(),
    #                      name='conv2',
    #                      reuse=tf.AUTO_REUSE)

    ##############################
    # This is the modification where the kernel and bias are initialized with 0
    ##############################

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.zeros_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv2',
                         reuse=tf.AUTO_REUSE)

    x = tf.layers.max_pooling2d(x, 2, 2, 'VALID')

    tf.set_random_seed(2)

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv3')

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv4')

    return x


def fc(x):
    x = tf.layers.flatten(x)

    x = tf.layers.dense(x, 128, activation=tf.nn.relu,
                        use_bias=True, 
                        kernel_initializer=tf.glorot_normal_initializer(), 
                        bias_initializer=tf.zeros_initializer(),
                        name='fc1')

    x = tf.layers.dense(x, 10, activation=None,
                        use_bias=True, 
                        kernel_initializer=tf.glorot_normal_initializer(), 
                        bias_initializer=tf.zeros_initializer(),
                        name='fc2')

    return x


def main():
    dataset = DataPipeline()

    image = tf.placeholder(tf.float32, shape=[None, 784])
    label = tf.placeholder(tf.int64, shape=[None])

    x = tf.reshape(image, [-1, 28, 28, 1])

    x = conv(x)

    x = fc(x)

    # record variable names so parameters can be snapshotted before and after training
    names = [v.name for v in tf.global_variables()]

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=x))

    train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(x, 1), label)

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    batch = dataset.next(1000, 'test')
    print("test accuracy %g" % sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]}))

    #####################
    # Extract all the variables (snapshot before training)
    #####################

    var_values = {}

    for name in names:
        var_values[name] = sess.run(tf.get_default_graph().get_tensor_by_name(name))

    batch = dataset.fixed(50)
    for i in range(1000):
        if i % 100 == 0:
            train_accuracy = sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]})
            print("step %d,train accuracy %g" % (i, train_accuracy))
        # take one gradient step on every iteration
        sess.run(train_step, feed_dict={image: batch[0], label: batch[1]})

    batch = dataset.fixed(1000, 'test')
    print("test accuracy %g" % sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]}))

    #####################
    # Extract the variables again and diff against the snapshot
    #####################

    for name in names:
        var_values[name] -= sess.run(tf.get_default_graph().get_tensor_by_name(name))

    for name in names:
        print(name)
        print(var_values[name])


if __name__ == '__main__':
    main()

Results
test accuracy 0.11
step 0,train accuracy 0.08
step 100,train accuracy 0.14
step 200,train accuracy 0.14
step 300,train accuracy 0.14
step 400,train accuracy 0.14
step 500,train accuracy 0.14
step 600,train accuracy 0.14
step 700,train accuracy 0.14
step 800,train accuracy 0.14
step 900,train accuracy 0.14
test accuracy 0.126
conv1/kernel:0
[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]


 [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]


 [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
    0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]]
conv1/bias:0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
conv2/kernel:0
[[[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]]
conv2/bias:0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
conv3/kernel:0
[[[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]]
conv3/bias:0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
conv4/kernel:0
[[[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]


 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]]
conv4/bias:0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
fc1/kernel:0
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
fc1/bias:0
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
fc2/kernel:0
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
fc2/bias:0
[ 1.9999105e-05 -3.9998205e-05  1.9999105e-05 -3.9998202e-05
  1.9999103e-05  1.9999105e-05  1.4156103e-11  1.9999105e-05
  1.9999105e-05 -3.9998198e-05]

  As the printout shows, only the final bias term was updated: it is directly connected to the output and the label, so it still receives a gradient, while every parameter before it cannot be updated at all. Such a model's expressiveness can improve only marginally and is useless for real work. Therefore, with biases defaulting to zero, the convolution kernels and fully connected weights must never be initialized to zero.
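  This matches the math. For the final layer z = W h + b under softmax cross-entropy loss, writing p = softmax(z) and y for the one-hot label, the gradients are

\frac{\partial L}{\partial b} = p - y,
\qquad
\frac{\partial L}{\partial W} = (p - y)\, h^{\top} .

  The bias gradient does not involve the layer's input h at all, so it stays nonzero even when h = 0, while the kernel gradient vanishes; every parameter further upstream is blocked by a zero input or a zero kernel, exactly as the printout shows.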
