https://github.com/zhao62/Deep-Residual-Shrinkage-Networks 代码链接
1.代码
1.1DRSN_keras.py
python版本为3.6
安装tensorflow1.15.0
直接利用tensorflow中的keras
import部分的代码改成:
from __future__ import print_function
import numpy as np
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow_core.python.keras.layers import Lambda
正文代码:
K.set_learning_phase(1)
# Input image dimensions
img_rows, img_cols = 28, 28
# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
def abs_backend(inputs):
return K.abs(inputs)
def expand_dim_backend(inputs):
return K.expand_dims(K.expand_dims(inputs,1),1)
def sign_backend(inputs):
return K.sign(inputs)
def pad_backend(inputs, in_channels, out_channels):
pad_dim = (out_channels - in_channels)//2
inputs = K.expand_dims(inputs,-1)
inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
return K.squeeze(inputs, -1)
# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2):
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
# Calculate global means
residual_abs = Lambda(abs_backend)(residual)
abs_mean = GlobalAveragePooling2D()(residual_abs)
# Calculate scaling coefficients
scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(abs_mean)
scales = BatchNormalization()(scales)
scales = Activation('relu')(scales)
scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
scales = Lambda(expand_dim_backend)(scales)
# Calculate thresholds
thres = keras.layers.multiply([abs_mean, scales])
# Soft thresholding
sub = keras.layers.subtract([residual_abs, thres])
zeros = keras.layers.subtract([sub, sub])
n_sub = keras.layers.maximum([sub, zeros])
residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
# Downsampling using the pooL-size of (1, 1)
if downsample_strides > 1:
identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
# Zero_padding to match channels
if in_channels != out_channels:
identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
residual = keras.layers.add([residual, identity])
return residual
# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))
# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])
实验结果:
1.2 DRSN_TFLearn.py
from __future__ import division, print_function, absolute_import
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
# Add noise
X = X + np.random.random((50000, 32, 32, 3)) * 0.1
testX = testX + np.random.random((10000, 32, 32, 3)) * 0.1
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2, activation='relu', batch_norm=True,
bias=True, weights_init='variance_scaling',
bias_init='zeros', regularizer='L2', weight_decay=0.0001,
trainable=True, restore=True, reuse=False, scope=None,
name="ResidualBlock"):
# residual shrinkage blocks with channel-wise thresholds
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
# Variable Scope fix for older TF
try:
vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
reuse=reuse)
except Exception:
vscope = tf.compat.v1.variable_op_scope([incoming], scope, name, reuse=reuse)
with vscope as scope:
name = scope.name # TODO
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3,
downsample_strides, 'same', 'linear',
bias, weights_init, bias_init,
regularizer, weight_decay, trainable,
restore)
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3, 1, 'same',
'linear', bias, weights_init,
bias_init, regularizer, weight_decay,
trainable, restore)
# get thresholds and apply thresholding
abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True), axis=1, keep_dims=True)
scales = tflearn.fully_connected(abs_mean, out_channels // 4, activation='linear', regularizer='L2',
weight_decay=0.0001, weights_init='variance_scaling')
scales = tflearn.batch_normalization(scales)
scales = tflearn.activation(scales, 'relu')
scales = tflearn.fully_connected(scales, out_channels, activation='linear', regularizer='L2',
weight_decay=0.0001, weights_init='variance_scaling')
scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
# soft thresholding
residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))
# Downsampling
if downsample_strides > 1:
identity = tflearn.avg_pool_2d(identity, 1,
downsample_strides)
# Projection to new dimension
if in_channels != out_channels:
if (out_channels - in_channels) % 2 == 0:
ch = (out_channels - in_channels) // 2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch]])
else:
ch = (out_channels - in_channels) // 2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch + 1]])
in_channels = out_channels
residual = residual + identity
return residual
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
data_preprocessing=img_prep,
data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
实验结果
2.模型
2.1残差网络
深度残差收缩网络就是对“深度残差网络”的残差路径进行收缩(软阈值化)的一种网络。
设计思想:在特征学习的过程中,剔除冗余信息也是非常重要的。
残差模块:
如下图所示,长方体表示通道数为C、宽度为W、高为1的特征图;一个残差模块可以包含两个批标准化(Batch Normalization, BN)、两个整流线性单元激活函数(Rectifier Linear Unit activation function, ReLU)、两个卷积层(Convolutional layer)和恒等映射(Identity shortcut)。恒等映射是深度残差网络的核心贡献,极大程度地降低了深度神经网络训练的难度。K表示卷积层中卷积核的个数。
(a)输入特征图尺寸=输出特征图尺寸。在残差模块中,输出特征图的宽度可以发生改变,图(b)将卷积层中卷积核的移动步长设置为2(用/2表示),那么输出特征图的宽度就会减半,变成0.5W。输出特征图的通道数也可以发生改变,图(c)将卷积层中卷积核的个数设置为2C,暑促出特着呢个图的通道数就会变成2C,使得输出特征图的通道数翻倍。图(d)是深度残差网络的整体示意图。
2.2深度残差收缩网络:
深度残差收缩网络面向的是带有“噪声”的信号,将“软阈值化”作为“收缩层”引入残差模块之中,并提出自适应设置阈值的方法。噪声可以理解为与当前任务没有关系的特征信息也就是干扰信息。
软阈值化:
网络结构:
这个子网络所设置的阈值,其实就是(特征图的绝对值的平均值)×(一个系数α)。在sigmoid函数的作用下,α是一个0和1之间的数字。在这种方式下,阈值不仅是一个正数,而且不会太大,即不会使输出全部为零。
图(a) 是改进的残差模块(通道间共享阈值),RSBU-CS。
图(c)是改进的残差模块(不同通道不同阈值),RSBU-CW。
3.体会
软阈值化是信号降噪里一个非常常见的概念,它指的是将一段信号的值,朝着“零”的方向进行收缩。
这种降噪方式有一个前提,接近于零的部分是噪声。但对于很多信号而言,接近于零的部分,可能包含许多有用的信息,不能直接剔除掉,所以通常不会直接对原始信号进行软阈值化处理。传统的思路是将原始信号进行某种变换,将原始信号转换成其他形式的表征。理想情况下,在这种转换后的表征里,接近与零的部分,是无用的噪声,再采用软阈值化对转换后的表征进行处理。最后将软阈值化处理之后的表征,重构回去,获得降噪后的信号。