ResNet: Summary and Implementation

Reposted from Zhihu, by Kissrabbit

I have been busy with my graduation project lately, and it looked like I needed to deepen my network. In the process I hit a problem: quite often, a more complex 8-layer network performed far worse than a relatively simple 3-layer one (though admittedly my hyperparameter-tuning skills may be partly to blame). This was frustrating. With the simple network, the learned model just wasn't good enough; but adding depth to make it more complex not only failed to reduce the training error, it also made training harder. After struggling with this for a while, I remembered ResNet and went back to read the paper (arxiv.org/pdf/1512.0338). It turned out to be exactly the method I was hoping for!

According to the paper, ResNet can keep reducing the training error as the network grows deeper, mitigating the vanishing-gradient problem and improving performance. Best of all, a ResNet can be very deep while remaining structurally simple: it is just a stack of one small, uniform module. The unit block is shown in the figure below:

If we define the network's desired output as H and the residual as H - x, then ResNet actually learns this residual F = H - x: the stacked layers fit F, and the block outputs F(x) + x. Note, however, that in a CNN the output dimensions can differ from the input dimensions, for instance after a pooling layer or a convolution with stride 2, so the branch output F and the input x may not have the same shape. For example, take an input x of shape [10, 28, 28, 1], where 10 is the batch size, 28 x 28 the image size, and 1 the number of channels. After pooling or a stride-2 convolution we may get F of shape [10, 14, 14, 64], and F and x can no longer be added directly. For this case, the ResNet authors suggest a 1x1 convolution with stride 2 on the shortcut, projecting x to shape [10, 14, 14, 64] so that it matches F and the addition works.
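To make the shape bookkeeping concrete, here is a small NumPy sketch (my own illustration, not code from the paper) of what the 1x1, stride-2 shortcut projection does: spatial subsampling plus a learned per-pixel channel mixing.

```python
import numpy as np

def project_shortcut(x, out_channels, stride=2):
    # A 1x1 convolution with stride s is just: take every s-th pixel,
    # then mix channels with a (c_in x c_out) weight matrix.
    n, h, w, c = x.shape
    w_1x1 = np.random.randn(c, out_channels) * 0.01  # stand-in for learned weights
    x_sub = x[:, ::stride, ::stride, :]              # spatial subsampling
    return x_sub @ w_1x1                             # per-pixel channel projection

x = np.random.randn(10, 28, 28, 1)    # input:  [batch, 28, 28, 1]
f = np.random.randn(10, 14, 14, 64)   # branch output after a stride-2 conv
shortcut = project_shortcut(x, 64)    # now [10, 14, 14, 64]
out = f + shortcut                    # shapes match, addition is valid
print(out.shape)                      # (10, 14, 14, 64)
```

In a real network the 1x1 weights are trained along with everything else; the point here is only that the projection reconciles both the spatial size and the channel count.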

That is essentially all there is to ResNet. Simple, isn't it? Yet this simple idea changed network performance qualitatively; the methods that deliver striking results are often the ones that look ordinary. Here is an example from the paper of a network stacked out of the block above:

The dashed lines in the figure mark the spots where the input and output dimensions no longer match, so the input has to be projected to the output dimensions; concretely, this is done with the 1x1 convolution described above.

For more details, read the original paper; the link is given at the beginning.


Now for the code!

The ResNet structure is simple, so I tried writing my own block implementation, resnet.py:

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

def res_identity(input_tensor, conv_depth, kernel_shape, layer_name):
    """Residual block that leaves the input tensor's dimensions unchanged."""
    with tf.variable_scope(layer_name):
        relu = tf.nn.relu(slim.conv2d(input_tensor, conv_depth, kernel_shape))
        # Second conv without activation, so the shortcut is added before the final ReLU
        residual = slim.conv2d(relu, conv_depth, kernel_shape, activation_fn=None)
        output_tensor = tf.nn.relu(residual + input_tensor)
    return output_tensor

def res_change(input_tensor, conv_depth, kernel_shape, layer_name):
    """Residual block that changes the input tensor's dimensions."""
    with tf.variable_scope(layer_name):
        relu = tf.nn.relu(slim.conv2d(input_tensor, conv_depth, kernel_shape, stride=2))
        # Project the input with a 1x1, stride-2 convolution so the dimensions match
        input_tensor_reshape = slim.conv2d(input_tensor, conv_depth, [1, 1], stride=2,
                                           activation_fn=None)
        residual = slim.conv2d(relu, conv_depth, kernel_shape, activation_fn=None)
        output_tensor = tf.nn.relu(residual + input_tensor_reshape)
    return output_tensor

The first function handles the case where the feature-map size does not change, so the input and output dimensions already match. The second handles the case where the size does change; here I use a stride-2 convolution rather than a pooling layer. (I also used the high-level slim wrapper just to keep the code concise; long code gets ugly, and nobody wants to read it.)
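Stripped of the framework, the two blocks reduce to a couple of lines each. A toy NumPy sketch, where F and project are placeholders I made up for the stacked convolutions and the 1x1 projection:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def res_identity_sketch(x, F):
    # Dimensions unchanged: add the input straight back in.
    return relu(F(x) + x)

def res_change_sketch(x, F, project):
    # Dimensions change inside F: project the shortcut to match first.
    return relu(F(x) + project(x))

x = np.random.randn(10, 14, 14, 32)
F = lambda t: 0.1 * t                                          # stand-in for two SAME-padded convs
out = res_identity_sketch(x, F)
print(out.shape)                                               # (10, 14, 14, 32), same as x

F2 = lambda t: 0.1 * np.repeat(t[:, ::2, ::2, :], 2, axis=3)   # halves H and W, doubles C
P = lambda t: np.repeat(t[:, ::2, ::2, :], 2, axis=3)          # matching shortcut projection
out2 = res_change_sketch(x, F2, P)
print(out2.shape)                                              # (10, 7, 7, 64)
```

The lambdas are crude stand-ins chosen only so the shapes work out; the structure of the two functions is what mirrors res_identity and res_change above.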

With this module in hand, you can build your own ResNet and then train and test it on the MNIST dataset. The training code, mnist_train.py:

import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
import win_unicode_console
win_unicode_console.enable()

import mnist_inference

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99

MODEL_SAVE_PATH = sys.path[0] + "/model/"
MODEL_NAME = "model.ckpt"

def train(mnist):
    x = tf.placeholder(tf.float32, [None, 28*28*1], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)

    y = mnist_inference.inference(x)
    global_step = tf.Variable(0, trainable=False)

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY)
    train_step = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # Create the Saver used to checkpoint the model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Validation/testing is done by a separate program, not during training
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_step, loss, global_step],
                                           feed_dict={x: xs, y_: ys})

            if i % 100 == 0:
                print("After %d training step(s), loss on training batch is %g" % (i, loss_value))
        saver.save(sess, MODEL_SAVE_PATH + MODEL_NAME, global_step=global_step)

def main(argv=None):
    mnist = input_data.read_data_sets("D:/work/mnist/", one_hot=True)  # replace with your own dataset path
    train(mnist)

if __name__ == '__main__':
    tf.app.run()

The forward-pass code, mnist_inference.py:

import tensorflow as tf
from resnet import *

slim = tf.contrib.slim

def inference(input_tensor):
    x_image = tf.reshape(input_tensor, [-1, 28, 28, 1])
    relu_1 = tf.nn.relu(slim.conv2d(x_image, 32, [3, 3]))
    pool_1 = slim.max_pool2d(relu_1, [2, 2])
    net = res_identity(pool_1, 32, [3, 3], 'layer_2')
    net = res_identity(net, 32, [3, 3], 'layer_3')

    net = slim.flatten(net, scope='flatten')
    # No activation on the logits: the softmax is applied inside the loss
    net = slim.fully_connected(net, 10, activation_fn=None, scope='output')

    return net

Only two residual blocks are added here; if you are interested, feel free to add more, even 20 would be no problem.

The evaluation code, mnist_eval.py:

import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import win_unicode_console
win_unicode_console.enable()
import mnist_inference
import mnist_train

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, 28*28*1], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        y = mnist_inference.inference(x)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                print("validation accuracy - %g" % accuracy_score)
            else:
                print('No checkpoint file found')
                return

def main(argv=None):
    mnist = input_data.read_data_sets("D:/work/mnist/", one_hot=True)  # replace with your own dataset path
    evaluate(mnist)

if __name__ == '__main__':
    tf.app.run()

The test accuracy comes out at 0.992. That is decent, but MNIST is such an easy dataset that it is completely unsuitable for demonstrating the superiority of one algorithm over another.


All in all, implementing ResNet by hand is quite easy, since the idea behind the method is simple; yet this simple idea brought an enormous change, and I really admire the authors' insight. If you think there is a problem with my block code in resnet.py, feel free to discuss it in the comments; corrections from experts are welcome.


Update 2018-4-24: the code above is the ResNet block I posted originally. I revised it a while ago; the main change is adding Batch Normalization (BN) after every convolution layer, which further mitigates the vanishing-gradient problem. There are plenty of blog posts about BN, so I won't repeat the details here.
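As a quick reminder of what BN computes, here is a NumPy sketch (training-mode batch statistics only; the learned scale gamma and shift beta default to 1 and 0 here):

```python
import numpy as np

def batch_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    # Normalize each channel over the batch and spatial axes (NHWC layout),
    # then apply the learned scale and shift.
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x = 5.0 + 3.0 * np.random.randn(10, 14, 14, 64)  # arbitrary pre-activation
y = batch_norm(x)
print(y.mean(), y.std())                          # ≈ 0 and ≈ 1
```

The BatchNormLayer used below additionally keeps running averages of the statistics for inference, which this sketch omits.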

First, the code:

import numpy as np
import tensorflow as tf
from tensorlayer.layers import *

slim = tf.contrib.slim

def res_identity(net0, conv_depth, kernel_shape, layer_name, train):
    """Residual block that leaves the input tensor's dimensions unchanged."""
    gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope(layer_name):
        net = Conv2d(net0, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_1')
        bn_1 = BatchNormLayer(net, act=tf.nn.relu, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_1')

        net = Conv2d(bn_1, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_2')
        bn_2 = BatchNormLayer(net, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_2')

        net = ElementwiseLayer(layer=[bn_2, net0], combine_fn=tf.add, name=layer_name+'/add')
        net.outputs = tf.nn.relu(net.outputs)
    return net

def res_change(net0, conv_depth, kernel_shape, layer_name, train):
    """Residual block that changes the input tensor's dimensions."""
    gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope(layer_name):
        net = Conv2d(net0, conv_depth, kernel_shape, strides=(2,2), b_init=None, name=layer_name+'/conv_1')
        bn_1 = BatchNormLayer(net, act=tf.nn.relu, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_1')

        # Project the shortcut with a 1x1, stride-2 convolution so the dimensions match
        net0_reshape = Conv2d(net0, conv_depth, (1,1), strides=(2,2), name=layer_name+'/conv_2')
        bn_2 = BatchNormLayer(net0_reshape, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_2')

        net = Conv2d(bn_1, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_3')
        bn_3 = BatchNormLayer(net, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_3')

        net = ElementwiseLayer(layer=[bn_3, bn_2], combine_fn=tf.add, name=layer_name+'/add')
        net.outputs = tf.nn.relu(net.outputs)
    return net

This version uses tensorlayer, a high-level wrapper around TF, so to run the code you need to install it, which is as simple as pip install tensorlayer. (Writing BN with tensorlayer takes a single line; doing it in raw TF is more cumbersome and the code becomes harder to read.)

ResNet is astonishingly effective on image classification; you have to try it yourself to believe it. A conventional 8-layer convolutional stack is not only slow but also hard to get to converge, or converges very slowly over a huge number of iterations, whereas switching to ResNet improves things dramatically. In one comparison I ran recently (which I unfortunately cannot publish), the plain convolutional network needed on the order of 100,000 steps to converge to a good value, while the ResNet needed only about 5,000. Pick a larger image-classification task, try it yourself, and you will see. (I also implemented DenseNet, but unless the task is very complex, ResNet seems perfectly adequate; I have not run a careful comparison, so I won't say more about it.)

I hope this article is helpful!

        <span class="n">_</span><span class="p">,</span> <span class="n">loss_value</span><span class="p">,</span> <span class="n">step</span> <span class="o" style="font-weight:600;">=</span> <span class="n">sess</span><span class="o" style="font-weight:600;">.</span><span class="n">run</span><span class="p">([</span><span class="n">train_step</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">global_step</span><span class="p">],</span> <span class="n">feed_dict</span><span class="o" style="font-weight:600;">=</span><span class="p">{</span><span class="n">x</span><span class="p">:</span><span class="n">xs</span><span class="p">,</span> <span class="n">y_</span><span class="p">:</span><span class="n">ys</span><span class="p">})</span>

        <span class="k" style="font-weight:600;">if</span> <span class="n">i</span> <span class="o" style="font-weight:600;">%</span> <span class="mi" style="color:rgb(0,132,255);">100</span> <span class="o" style="font-weight:600;">==</span> <span class="mi" style="color:rgb(0,132,255);">0</span><span class="p">:</span>
             <span class="k" style="font-weight:600;">print</span><span class="p">(</span><span class="s2" style="color:rgb(241,64,60);">"After </span><span class="si" style="color:rgb(241,64,60);">%d</span><span class="s2" style="color:rgb(241,64,60);"> training step(s), loss on training batch is </span><span class="si" style="color:rgb(241,64,60);">%g</span><span class="s2" style="color:rgb(241,64,60);"> "</span> <span class="o" style="font-weight:600;">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">loss_value</span><span class="p">))</span>
    <span class="n">saver</span><span class="o" style="font-weight:600;">.</span><span class="n">save</span><span class="p">(</span><span class="n">sess</span><span class="p">,</span> <span class="n">MODEL_SAVE_PATH</span><span class="o" style="font-weight:600;">+</span><span class="n">MODEL_NAME</span><span class="p">,</span> <span class="n">global_step</span><span class="o" style="font-weight:600;">=</span><span class="n">global_step</span><span class="p">)</span>

def main(argv=None):
    mnist = input_data.read_data_sets("D:/work/mnist/", one_hot=True)  # replace with your own dataset path
    train(mnist)

if __name__ == '__main__':
    tf.app.run()
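The training script above references several hyperparameter constants defined earlier in mnist_train.py, outside this excerpt. The names below match those references, but the values are illustrative assumptions of mine, not the author's originals:

```python
# Hypothetical hyperparameter values assumed by the training loop above.
# The names are the ones the script references; the values are placeholders.
BATCH_SIZE = 100                    # mini-batch size for next_batch()
LEARNING_RATE_BASE = 0.8            # initial rate for exponential decay
LEARNING_RATE_DECAY = 0.99          # decay factor per epoch
TRAINING_STEPS = 30000              # total optimization steps
MODEL_SAVE_PATH = "model/"          # checkpoint directory
MODEL_NAME = "resnet_mnist.ckpt"    # checkpoint file prefix
```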

The forward-propagation code, mnist_inference.py:

import tensorflow as tf
from resnet import *

slim = tf.contrib.slim

def inference(input_tensor):
    x_image = tf.reshape(input_tensor, [-1, 28, 28, 1])
    relu_1 = tf.nn.relu(slim.conv2d(x_image, 32, [3, 3]))
    pool_1 = slim.max_pool2d(relu_1, [2, 2])
    net = res_identity(pool_1, 32, [3, 3], 'layer_2')
    net = res_identity(net, 32, [3, 3], 'layer_3')

    net = slim.flatten(net, scope='flatten')
    net = slim.fully_connected(net, 10, scope='output')

    return net

Only two ResNet blocks are stacked here; if you're interested, add more. Stacking 20 of them would work just as well.
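A toy numpy sketch (mine, not the author's code) of why stacking many identity blocks stays well-behaved: each block computes F(x) + x, so when the residual branch F is small the block is close to the identity map, and the signal survives a deep stack instead of vanishing:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(v):
    return np.maximum(v, 0.0)

def res_identity_np(x, w1, w2):
    # One identity block: residual branch w2 @ relu(w1 @ x), plus the shortcut x.
    return relu(w2 @ relu(w1 @ x) + x)

d = 8
x = relu(rng.standard_normal(d))  # start from a non-negative activation
x0 = x.copy()
# Stack 20 blocks with small residual weights: each block barely perturbs
# its input, so after all 20 the output is still close to the original signal.
for _ in range(20):
    w1 = 0.01 * rng.standard_normal((d, d))
    w2 = 0.01 * rng.standard_normal((d, d))
    x = res_identity_np(x, w1, w2)
```

In a plain (non-residual) stack, 20 such small-weight layers would shrink the signal toward zero; the shortcut connection is what carries it through.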

The evaluation code, mnist_eval.py:

import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import win_unicode_console
win_unicode_console.enable()
import mnist_inference
import mnist_train

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, 28*28*1], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        y = mnist_inference.inference(x)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                print("validation accuracy - %g" % accuracy_score)
            else:
                print('No checkpoint file found')
                return

def main(argv=None):
    mnist = input_data.read_data_sets("D:/work/mnist/", one_hot=True)  # replace with your own dataset path
    evaluate(mnist)

if __name__ == '__main__':
    tf.app.run()
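The tf.equal / tf.argmax / tf.reduce_mean accuracy computation in the evaluation code boils down to this numpy arithmetic (the toy logits and labels here are mine, for illustration):

```python
import numpy as np

# Three samples, three classes. Rows are logits for one sample each.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 0.5, 3.0],
                   [1.5, 2.5, 0.5]])
# One-hot labels: the true classes are 0, 1, 2.
labels = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]])

# Predicted class = argmax over logits; compare with the label's argmax.
correct = np.argmax(logits, axis=1) == np.argmax(labels, axis=1)
# Cast the booleans to float and average, just like tf.reduce_mean(tf.cast(...)).
accuracy = correct.astype(np.float32).mean()
# Predictions are 0, 2, 1 -> only the first matches, so accuracy is 1/3.
```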

The validation accuracy comes out to 0.992. That looks good, but MNIST is simply too easy a dataset; if you want to demonstrate that your algorithm is superior, MNIST won't cut it.


All in all, implementing ResNet yourself is quite easy, since the underlying idea is so simple. Yet this simple idea brought about a huge change, and I really admire the authors' insight. If you spot problems with my block code in resnet.py, please point them out in the comments; corrections from experienced readers are welcome.


Update 2018-4-24: The code above is the ResNet block code I posted originally. I revised it a while ago; the main change is adding Batch Normalization (BN) after each convolutional layer, which further mitigates the vanishing-gradient problem. There are plenty of blog posts covering BN, so I won't go into the details here.
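As a quick refresher, the BN transform those layers apply is just a per-feature normalization over the batch, followed by a learned scale and shift. A minimal numpy sketch (not TensorLayer's implementation):

```python
import numpy as np

def batch_norm_np(x, gamma, beta, eps=1e-5):
    # Normalize each feature over the batch dimension, then scale by gamma
    # and shift by beta (both learned parameters in a real network).
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # batch of 64, 4 features
out = batch_norm_np(x, gamma=np.ones(4), beta=np.zeros(4))
# With gamma=1, beta=0, each feature of `out` has (approximately)
# zero mean and unit variance over the batch.
```

At training time a real BN layer uses batch statistics as above; at inference it switches to running averages, which is why the blocks below take an is_train flag.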

Here is the code:

import numpy as np
import tensorflow as tf
from tensorlayer.layers import *

slim = tf.contrib.slim

def res_identity(net0, conv_depth, kernel_shape, layer_name, train):
    """Residual block that keeps the input tensor's dimensions unchanged."""
    gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope(layer_name):
        net = Conv2d(net0, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_1')
        bn_1 = BatchNormLayer(net, act=tf.nn.relu, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_1')

        net = Conv2d(bn_1, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_2')
        bn_2 = BatchNormLayer(net, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_2')

        net = ElementwiseLayer(layer=[bn_2, net0], combine_fn=tf.add, name=layer_name+'/add')
        net.outputs = tf.nn.relu(net.outputs)
    return net

def res_change(net0, conv_depth, kernel_shape, layer_name, train):
    """Residual block that changes the input tensor's dimensions."""
    gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope(layer_name):
        net = Conv2d(net0, conv_depth, kernel_shape, strides=(2,2), b_init=None, name=layer_name+'/conv_1')
        bn_1 = BatchNormLayer(net, act=tf.nn.relu, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_1')

        # 1x1 stride-2 convolution on the shortcut so its dimensions match the output
        net0_reshape = Conv2d(net0, conv_depth, (1,1), strides=(2,2), name=layer_name+'/conv_2')
        bn_2 = BatchNormLayer(net0_reshape, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_2')

        net = Conv2d(bn_1, conv_depth, kernel_shape, b_init=None, name=layer_name+'/conv_3')
        bn_3 = BatchNormLayer(net, act=tf.identity, is_train=train, gamma_init=gamma_init, name=layer_name+'/bn_3')

        net = ElementwiseLayer(layer=[bn_3, bn_2], combine_fn=tf.add, name=layer_name+'/add')
        net.outputs = tf.nn.relu(net.outputs)
    return net

This version uses TensorLayer, a high-level wrapper around TensorFlow, so you will need to install it to run the code. Installation is trivial: pip install tensorlayer on the command line. (Writing BN in TensorLayer takes a single line; doing it in raw TensorFlow is noticeably more cumbersome and hurts the code's readability.)

ResNet is remarkably effective for image classification; you have to try it to believe it. A conventionally stacked 8-layer CNN is not only slow but also hard to get to converge, or converges very slowly over many iterations. With ResNet the improvement is dramatic: where a plain CNN might need 100,000 steps to reach a good loss, the ResNet version might need only 5,000. (I ran exactly such a comparison recently; since that data isn't something I can share publicly, pick a larger image-classification task and try it yourself.) So for classification, ResNet really delivers. (I've also implemented DenseNet, but unless the task is very complex, ResNet is usually good enough; I haven't compared the two experimentally, so I won't comment further.)

I hope this article helps!

