深度学习的resnet的原理比较容易懂,但是具体的代码构建还是有些复杂,尤其是在tensorflow 的框架下构建比较复杂,在今天的博文中主要介绍了残差单元的构建:
import collections
import tensorflow as tf
import tensorflow.contrib.slim as slim
def subsample(inputs, factor, scope=None):
    """Spatially subsample `inputs` by `factor`.

    A factor of 1 is a no-op; otherwise a 1x1 max-pool with the given
    stride is used, which keeps the channel count unchanged while
    reducing height/width by `factor`.

    Args:
        inputs: A tensor of size [batch, height, width, channels].
        factor: Integer subsampling stride.
        scope: Optional variable-scope name for the pooling op.

    Returns:
        The (possibly subsampled) tensor.
    """
    # Guard clause: nothing to do when no subsampling is requested.
    if factor == 1:
        return inputs
    return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
上述这段代码构建了一个降采样函数:当 factor 为 1 时直接返回输入,否则用 1×1 最大池化按步长 factor 对输入做空间降采样,使 shortcut 分支与残差分支的尺寸保持一致。
def conv2d_same(inputs, num_outputs, kernel_size, stride, scope=None):
    """Strided 2-D convolution whose output size matches 'SAME' padding.

    For stride == 1 this is a plain SAME-padded convolution. For
    stride > 1 it performs the SAME-padded convolution at stride 1 and
    then subsamples the result, so the spatial size matches what a
    strided SAME convolution would produce.

    Args:
        inputs: A tensor of size [batch, height, width, channels].
        num_outputs: Number of output channels.
        kernel_size: Convolution kernel size (int or [h, w]).
        stride: Output stride.
        scope: Optional variable-scope name for the convolution.

    Returns:
        The convolved (and possibly subsampled) tensor.
    """
    if stride == 1:
        return slim.conv2d(inputs, num_outputs, kernel_size, stride=1,
                           padding='SAME', scope=scope)
    # Bug fix: the original hard-coded a 3x3 kernel here (ignoring the
    # kernel_size argument) and dropped `scope`, producing an
    # auto-generated variable-scope name inconsistent with the
    # stride == 1 branch.
    net = slim.conv2d(inputs, num_outputs, kernel_size, stride=1,
                      padding='SAME', scope=scope)
    return subsample(net, factor=stride)
若 stride==1,则直接做一次 padding 为 SAME 的卷积;否则先以步长 1 做卷积,再按 stride 降采样,使 shortcut 与 residual 的空间尺寸保持一致。
def bottleneck(inputs, depth, depth_bottleneck, stride,
               outputs_collections=None, scope=None):
    """ResNet v2 (pre-activation) bottleneck residual unit.

    Args:
        inputs: A tensor of size [batch, height, width, channels].
        depth: Number of output channels of the unit.
        depth_bottleneck: Number of channels in the inner 1x1/3x3 convs.
        stride: Spatial stride applied by the unit.
        outputs_collections: Collection used to gather end_points.
        scope: Optional variable-scope name.

    Returns:
        The output tensor of the residual unit.
    """
    with tf.variable_scope(scope, 'bottleneck_v2') as sc:
        # Channel count of the incoming tensor (its last dimension).
        in_channels = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
        # Pre-activation: batch norm + ReLU applied before the convolutions.
        preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu,
                                 scope='preact')
        if depth == in_channels:
            # Input and output channel counts match: the shortcut is a
            # (possibly strided) identity obtained by subsampling the
            # raw inputs.
            shortcut = subsample(inputs, stride, 'shortcut')
        else:
            # Channel counts differ: project the pre-activated input with
            # a strided 1x1 conv (no normalization or activation) so the
            # shapes line up for the final addition.
            shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
                                   normalizer_fn=None, activation_fn=None,
                                   scope='shortcut')
        # Residual branch: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.
        residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
                               scope='conv1')
        residual = conv2d_same(residual, depth_bottleneck, 3, stride,
                               scope='conv2')
        # Final 1x1 conv carries no normalizer and no activation.
        residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                               normalizer_fn=None, activation_fn=None,
                               scope='conv3')
        # Element-wise sum of the shortcut and residual branches.
        output = shortcut + residual
        return slim.utils.collect_named_outputs(outputs_collections,
                                                sc.name, output)
最后是整个残差块的设计。