# 利用Tensorflow的Slim API实现卷积神经网络

import os
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.slim as slim
import time
from tensorflow.examples.tutorials.mnist import input_data
# %matplotlib inline

# Load the MNIST data set; copy its four data files into the "data"
# sub-directory of the program's working directory.
# Bug fix: `mnist` was used throughout the script but never created.
# one_hot=True matches the [None, 10] label placeholder and argmax usage below.
mnist = input_data.read_data_sets("data/", one_hot=True)

trainimg   = mnist.train.images
trainlabel = mnist.train.labels
valimg     = mnist.validation.images
vallabel   = mnist.validation.labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels


jupyter notebook运行结果：

Extracting Z:\CarlWu\temp\machinelearning_course\Hadoop_cn\deeplearning\DeepLearningCourseCodes-master\04_CNN_advances\data/train-images-idx3-ubyte.gz


## 定义神经网络模型

# Network dimensions: MNIST images are 28*28 = 784 pixels; 10 digit classes.
n_input = 784
n_classes = 10
# Graph inputs: flattened image batch, one-hot label batch, and a boolean
# flag that switches batch-norm / dropout between training and inference.
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])
is_training = tf.placeholder(tf.bool)

def lrelu(x, leak=0.2, name='lrelu'):
    """Leaky ReLU activation: x for x >= 0, leak * x for x < 0.

    Implemented as a single linear combination with |x| so it stays a
    smooth graph expression: 0.5*(1+leak)*x + 0.5*(1-leak)*|x|.
    """
    with tf.variable_scope(name):
        half_sum  = 0.5 * (1 + leak)
        half_diff = 0.5 * (1 - leak)
        return half_sum * x + half_diff * abs(x)

def CNN(inputs, is_training=True):
    """Two conv/pool stages followed by a fully connected classifier.

    Args:
        inputs: flattened image batch, shape [batch, 784].
        is_training: bool (or bool tensor) controlling batch norm / dropout.

    Returns:
        Unscaled class logits of shape [batch, n_classes].
    """
    x = tf.reshape(inputs, [-1, 28, 28, 1])
    # Bug fix: this dict literal was missing its closing brace.
    # 'updates_collections': None makes slim.batch_norm update its moving
    # averages in place during training -- no separate update ops are run
    # anywhere in this script, so without it inference-mode stats stay stale.
    batch_norm_params = {'is_training': is_training, 'decay': 0.9,
                         'updates_collections': None}
    init_func = tf.truncated_normal_initializer(stddev=0.01)
    net = slim.conv2d(x, 32, [5, 5], padding='SAME',
                      activation_fn=lrelu,
                      weights_initializer=init_func,
                      normalizer_fn=slim.batch_norm,
                      normalizer_params=batch_norm_params,
                      scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    # Bug fix: conv2 previously took `x` (the raw input) instead of `net`,
    # silently skipping the first conv/pool stage (the logged weight shapes
    # conv2 (5,5,1,64) and fc4 (12544,1024) confirm this happened).
    net = slim.conv2d(net, 64, [5, 5], padding='SAME',
                      activation_fn=lrelu,
                      weights_initializer=init_func,
                      normalizer_fn=slim.batch_norm,
                      normalizer_params=batch_norm_params,
                      scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.flatten(net, scope='flatten3')
    net = slim.fully_connected(net, 1024,
                               activation_fn=lrelu,
                               weights_initializer=init_func,
                               normalizer_fn=slim.batch_norm,
                               normalizer_params=batch_norm_params,
                               scope='fc4')
    net = slim.dropout(net, keep_prob=0.7, is_training=is_training, scope='dr')
    # Output layer: plain linear logits (softmax is applied by the loss).
    out = slim.fully_connected(net, n_classes,
                               activation_fn=None, normalizer_fn=None,
                               scope='fco')
    return out
print ("神经网络准备完毕")



## 定义图结构

# PREDICTION
pred = CNN(x, is_training)

# LOSS AND OPTIMIZER
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=pred))
# Bug fix: `optm` is executed by the training loop below but was never
# defined anywhere in the script.
optm = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))  # per-sample hit/miss
accr = tf.reduce_mean(tf.cast(corr, "float"))         # mean accuracy

# INITIALIZER
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Inspect the trainable variables and their shapes.
print ("=================== TRAINABLE VARIABLES ===================")
t_weights = tf.trainable_variables()
var_names_list = [v.name for v in t_weights]
for i in range(len(t_weights)):
    wval = sess.run(t_weights[i])
    # Typo fix: "SAHPE" -> "SHAPE".
    print ("[%d/%d] [%s] / SHAPE IS %s"
           % (i, len(t_weights), var_names_list[i], wval.shape,))


Jupyter notebook输出结果：

=================== TRAINABLE VARIABLES ===================
[0/8] [conv1/weights:0] / SAHPE IS (5, 5, 1, 32)
[1/8] [conv1/BatchNorm/beta:0] / SAHPE IS (32,)
[2/8] [conv2/weights:0] / SAHPE IS (5, 5, 1, 64)
[3/8] [conv2/BatchNorm/beta:0] / SAHPE IS (64,)
[4/8] [fc4/weights:0] / SAHPE IS (12544, 1024)
[5/8] [fc4/BatchNorm/beta:0] / SAHPE IS (1024,)
[6/8] [fco/weights:0] / SAHPE IS (1024, 10)
[7/8] [fco/biases:0] / SAHPE IS (10,)



# Store model checkpoints in a directory under the "nets" sub-directory.
# Bug fix: the original comment line began with a full-width '＃' character,
# which is a Python syntax error.
savedir = "nets/cnn_mnist_modern/"
saver = tf.train.Saver(max_to_keep=100)
save_step = 4
if not os.path.exists(savedir):
    os.makedirs(savedir)

#增加图片数据，训练模型
def augment_img(xs):
    """Randomly rotate, zoom and shift each flattened 28x28 image in `xs`.

    Args:
        xs: array of shape [batch, 784]; each row is one grayscale image.

    Returns:
        Array of the same shape/dtype where every row is an independently
        augmented copy of the corresponding input image.
    """
    out = np.copy(xs)
    xs_r = np.reshape(xs, [-1, 28, 28])
    bg_value = 0   # background fill value used by rotate/shift (loop invariant)
    rg = 0.1       # max extra zoom: zoom factor drawn from [1.0, 1.1)
    for i in range(xs_r.shape[0]):
        xs_img = xs_r[i, :, :]
        # ROTATE by a random angle in [-15, 15) degrees.
        # Bug fix: the angle used to be a 1-element ndarray; ndimage.rotate
        # expects a scalar float (newer SciPy rejects the array form).
        angle = float(np.random.randint(-15, 15, 1)[0])
        xs_img = ndimage.rotate(xs_img, angle, reshape=False, cval=bg_value)
        # ZOOM in by a random factor, then trim back to the original size.
        zoom_factor = np.random.uniform(1., 1. + rg)
        h, w = xs_img.shape[:2]
        zh = int(np.round(zoom_factor * h))
        zw = int(np.round(zoom_factor * w))
        top = (zh - h) // 2
        left = (zw - w) // 2
        # Extra trailing 1s keep any additional axes unscaled.
        zoom_tuple = (zoom_factor,) * 2 + (1,) * (xs_img.ndim - 2)
        temp = ndimage.zoom(xs_img[top:top + zh, left:left + zw], zoom_tuple)
        trim_top = (temp.shape[0] - h) // 2
        trim_left = (temp.shape[1] - w) // 2
        xs_img = temp[trim_top:trim_top + h, trim_left:trim_left + w]
        # SHIFT by up to 3 pixels along each axis.
        shift = np.random.randint(-3, 3, 2)
        xs_img = ndimage.shift(xs_img, shift, cval=bg_value)
        # Flatten the augmented image back into its output row.
        out[i, :] = np.reshape(xs_img, [1, -1])
    return out


# PARAMETERS
training_epochs = 50
batch_size      = 50
display_step    = 3
val_acc         = 0
val_acc_max     = 0
best_epoch      = 0   # defensive init; updated whenever val accuracy improves
# OPTIMIZE
currentTime = time.time()
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = int(mnist.train.num_examples / batch_size)
    # ITERATION over mini-batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # AUGMENT DATA before feeding it to the network
        batch_xs = augment_img(batch_xs)
        feeds = {x: batch_xs, y: batch_ys, is_training: True}
        # Improvement: run the optimizer and fetch the loss in one session
        # call (the original ran the graph twice per batch).
        _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
        avg_cost += batch_cost
    avg_cost = avg_cost / total_batch
    # DISPLAY progress every `display_step` epochs
    if (epoch + 1) % display_step == 0:
        print('time spent is ', (time.time() - currentTime))
        currentTime = time.time()
        print ("Epoch: %03d/%03d cost: %.9f" % (epoch + 1, training_epochs, avg_cost))
        # Training accuracy on a random subset of 500 samples.
        randidx = np.random.permutation(trainimg.shape[0])[:500]
        feeds = {x: trainimg[randidx], y: trainlabel[randidx], is_training: False}
        train_acc = sess.run(accr, feed_dict=feeds)
        print (" TRAIN ACCURACY: %.5f" % (train_acc))

        # Validation accuracy is computed in batches (feeding the whole
        # validation set at once did not work).
        total_batch_val = int(valimg.shape[0] / batch_size)
        # Bug fix: the original print had a stray comma before '%'
        # (a syntax error).
        print("在验证数据集上分%d批计算准确度" % total_batch_val)
        val_acc_sum = 0.0
        for j in range(total_batch_val):
            # Bug fix: the slice upper bound used valimg.shape[0]-1,
            # which silently dropped the last validation sample.
            hi = min((j + 1) * batch_size, valimg.shape[0])
            feeds = {x: valimg[j * batch_size:hi],
                     y: vallabel[j * batch_size:hi],
                     is_training: False}
            val_acc_sum += sess.run(accr, feed_dict=feeds)
        val_acc = val_acc_sum / total_batch_val

        print (" 在验证数据集上的准确度为: %.5f" % (val_acc))
    # SAVE a checkpoint every `save_step` epochs
    # NOTE(review): placed at epoch level (not inside the display branch) to
    # match the logged output, where saves interleave with display epochs.
    if (epoch + 1) % save_step == 0:
        savename = savedir + "net-" + str(epoch) + ".ckpt"
        saver.save(sess=sess, save_path=savename)
        print (" [%s] SAVED." % (savename))
    # MAXIMUM VALIDATION ACCURACY -- remember the best epoch seen so far
    if val_acc > val_acc_max:
        val_acc_max = val_acc
        best_epoch = epoch
        print ("\x1b[31m BEST EPOCH UPDATED!! [%d] \x1b[0m" % (best_epoch))
print ("OPTIMIZATION FINISHED")

time spent is  595.5124831199646
Epoch: 003/050 cost: 0.056146707
TRAIN ACCURACY: 0.99200
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99160
BEST EPOCH UPDATED!! [2]
[nets/cnn_mnist_modern/net-3.ckpt] SAVED.
time spent is  644.9777743816376
Epoch: 006/050 cost: 0.052948017
TRAIN ACCURACY: 0.99400
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99020
[nets/cnn_mnist_modern/net-7.ckpt] SAVED.
time spent is  689.395813703537
Epoch: 009/050 cost: 0.052893652
TRAIN ACCURACY: 0.99200
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99180
BEST EPOCH UPDATED!! [8]
time spent is  598.4757721424103
...
...
Epoch: 042/050 cost: 0.037603188
TRAIN ACCURACY: 0.99200
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99500
[nets/cnn_mnist_modern/net-43.ckpt] SAVED.
time spent is  689.3062949180603
Epoch: 045/050 cost: 0.034730853
TRAIN ACCURACY: 0.99400
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99520
BEST EPOCH UPDATED!! [44]
time spent is  616.6805007457733
Epoch: 048/050 cost: 0.035798393
TRAIN ACCURACY: 0.99800
total batch val: total_batch_val  100
VALIDATION ACCURACY: 0.99340
[nets/cnn_mnist_modern/net-47.ckpt] SAVED.
OPTIMIZATION FINISHED

# Restore the checkpoint of the best epoch and evaluate on the test set.
best_epoch = 47
restorename = savedir + "net-" + str(best_epoch) + ".ckpt"
saver.restore(sess, restorename)
test_feeds = {x: testimg, y: testlabel, is_training: False}
test_acc = sess.run(accr, feed_dict=test_feeds)
print ("TEST ACCURACY: %.5f" % (test_acc))

## 总结下遇到的问题及解决方法：

原来不能工作的代码（一次性把整个验证集喂入计算准确度）：

feeds = {x: valimg, y: vallabel, is_training: False}

val_acc = sess.run(accr, feed_dict=feeds)

修改后的代码（把验证集分批计算准确度再求平均）：

total_batch_val = int(valimg.shape[0]/batch_size)
print("在验证数据集上分%d批计算准确度" % total_batch_val)
val_acc_sum = 0.0
for j in range(total_batch_val):
feeds = {x: valimg[j*batch_size:min((j+1)*batch_size,valimg.shape[0]-1)],
y: vallabel[j*batch_size:min((j+1)*batch_size,valimg.shape[0]-1)],
is_training: False}

val_acc = sess.run(accr, feed_dict=feeds)
val_acc_sum = val_acc_sum + val_acc

val_acc = val_acc_sum/total_batch_val
#代码修改结束


• 广告
• 抄袭
• 版权
• 政治
• 色情
• 无意义
• 其他

120