本文使用tensorflow实现resnet,使用cifar-10、mnist作为数据集进行测试
ResNet网络构建
本文实现的ResNet为ResNet v2版本的block,并实现ResNet-34
res block v2版本,与v1版本的区别主要在于BN、ReLU、conv的顺序不同,文章中提出使用BN->ReLU->conv的顺序效果最好
实现如下:
def res_block_v2(self, input, out_channels, kernel_size=3, stride=1):
    """Pre-activation residual block (ResNet v2): BN -> ReLU -> conv, twice.

    :param input: 4-D input tensor, NHWC layout.
    :param out_channels: number of output feature channels.
    :param kernel_size: spatial size of both conv kernels.
    :param stride: stride of the first conv; stride > 1 downsamples.
    :return: output tensor of the block (main branch + shortcut).
    """
    input_channels = input.get_shape().as_list()[3]
    inner = input
    # First BN -> ReLU -> conv; this conv carries the block's stride.
    inner = tf.layers.batch_normalization(
        inner, training=self.is_training,
        gamma_initializer=tf.truncated_normal_initializer(stddev=0.1))
    inner = tf.nn.relu(inner)
    inner = tf.layers.conv2d(
        inner, out_channels, [kernel_size, kernel_size],
        strides=[stride, stride], padding="SAME", use_bias=True,
        activation=None,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))
    # Second BN -> ReLU -> conv, always stride 1.
    inner = tf.layers.batch_normalization(
        inner, training=self.is_training,
        gamma_initializer=tf.truncated_normal_initializer(stddev=0.1))
    inner = tf.nn.relu(inner)
    inner = tf.layers.conv2d(
        inner, out_channels, [kernel_size, kernel_size], strides=[1, 1],
        padding="SAME", use_bias=True, activation=None,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))
    # Shortcut branch: a 1x1 projection is required whenever the spatial
    # size or channel count changes (original only checked channel growth;
    # '!=' also covers channel reduction and is equivalent for all callers
    # in this file, where channels only grow).
    if stride > 1 or out_channels != input_channels:
        shortcut = tf.layers.conv2d(
            input, out_channels, [1, 1], strides=[stride, stride],
            padding="SAME", use_bias=True, activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))
    else:
        shortcut = input
    # Debug prints removed; shapes can be inspected with TensorBoard.
    return inner + shortcut
基于v2的resnet-34实现如下:
def resnet_v2_34(self, input):
    """Build a ResNet-34 with pre-activation (v2) residual blocks.

    Stage layout follows ResNet-34: 3/4/6/3 blocks with 64/128/256/512
    channels, then global average pooling, a 10-way dense layer and
    softmax.

    :param input: 4-D image tensor, NHWC layout.
    :return: (batch, 10) softmax probabilities.
    """
    layers = []
    inner = input
    with tf.variable_scope('conv1'):
        # Stem: 7x7 convolution with stride 2.
        inner = tf.layers.conv2d(inner, 64, [7, 7], padding='SAME', strides=[2, 2],
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))
        layers.append(inner)
    with tf.variable_scope('conv2'):
        inner = tf.layers.max_pooling2d(inner, [3, 3], [2, 2], padding='SAME')
        for i in range(3):
            inner = self.res_block_v2(inner, 64)
        layers.append(inner)
    with tf.variable_scope('conv3'):
        # First block of each later stage downsamples with stride 2.
        inner = self.res_block_v2(inner, 128, stride=2)
        for i in range(3):
            inner = self.res_block_v2(inner, 128)
        layers.append(inner)
    with tf.variable_scope('conv4'):
        inner = self.res_block_v2(inner, 256, stride=2)
        for i in range(5):
            inner = self.res_block_v2(inner, 256)
        layers.append(inner)
    with tf.variable_scope('conv5'):
        inner = self.res_block_v2(inner, 512, stride=2)
        for i in range(2):
            inner = self.res_block_v2(inner, 512)
        layers.append(inner)
    # NOTE(review): the ResNet v2 paper applies a final BN+ReLU after the
    # last residual block before pooling; it is omitted here — confirm intent.
    with tf.variable_scope('global_average_pool'):
        # Mean over height and width -> (batch, 512).
        inner = tf.reduce_mean(inner, [1, 2])
        layers.append(inner)
    with tf.variable_scope('fc'):
        inner = tf.layers.dense(inner, 10,
                                kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))
        layers.append(inner)
    with tf.variable_scope('softmax'):
        inner = tf.nn.softmax(inner)
    # Debug shape prints removed.
    return inner
数据集读取
mnist数据集的读取在tf中有官方API,这里不再赘述,直接上代码
# MNIST is loaded through the official TF tutorial API.
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("./MNIST_data",one_hot=True)
#get train data and label
x_, y_ = mnist.train.next_batch(self.batch_size)
#reshape flat pixels into NHWC images and labels into one-hot rows
xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
yy=np.reshape(y_,[self.batch_size,self.class_num])
#get test data and label
x_, y_ = mnist.test.next_batch(self.batch_size)
#reshape (same layout as the training batch)
xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
yy=np.reshape(y_,[self.batch_size,self.class_num])
cifar-10数据集的目录如下:
batches.meta、readme.html
data_batch_1、data_batch_2、data_batch_3、data_batch_4、data_batch_5
test_batch
训练数据都在data_batch中,一共5个文件,每个文件10000个样本,一共50000个训练样本,test_batch中为测试集,也是10000个样本。每个样本是32*32的彩色图像(3个通道),每个样本所占的字节为1(标签)+3*32*32(像素数据)个
读取cifar-10数据集的代码如下:
import numpy as np
import pickle
# 读取单个的batch文件
def unpickle(file, base_dir='D:\\pyproject\\data\\CIFAR\\cifar-10-python\\cifar-10-batches-py\\'):
    """Load one pickled CIFAR-10 batch file.

    :param file: batch file name, e.g. 'data_batch_1' or 'test_batch'.
    :param base_dir: directory (with trailing separator) containing the
        batch files; defaults to the original hard-coded location so
        existing callers are unaffected.
    :return: dict with byte-string keys (b'data', b'labels', ...).
    """
    # Renamed the local from 'dict' to avoid shadowing the builtin.
    with open(base_dir + file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
def one_hot(x, n):
    """Convert a 1-D vector of class indices into an (len(x), n) one-hot
    float matrix by indexing rows of the identity matrix."""
    indices = np.array(x)
    assert indices.ndim == 1
    identity = np.eye(n)
    return identity[indices]
def make_one_hot(x, n):
    """Vectorized one-hot encoding of a 1-D integer index array.

    :param x: 1-D numpy array of class indices.
    :param n: number of classes.
    :return: (len(x), n) integer array with a single 1 per row.
    """
    # np.integer is an abstract scalar type; passing it to astype is
    # deprecated in modern NumPy — the builtin int gives the same
    # platform-default integer dtype.
    return (np.arange(n) == x[:, None]).astype(int)
class cifar_reader:
    """Reader for the python-pickle version of the CIFAR-10 dataset."""

    def __init__(self, data_dir, image_height=32, image_width=32,
                 image_depth=3, label_bytes=1):
        # NOTE(review): data_dir is stored but train_reader currently relies
        # on the path hard-coded inside unpickle() — confirm intent.
        self.data_dir = data_dir
        self.image_height = image_height
        self.image_width = image_width
        self.image_depth = image_depth
        self.label_bytes = label_bytes

    def train_reader(self, batch_index):
        """Load one training batch file (10000 samples).

        :param batch_index: 1-based index of the data_batch_* file.
        :return: (data, labels) — data is (10000, 32, 32, 3) in NHWC order,
            labels is (10000, 10) one-hot.
        """
        batch = unpickle('data_batch_' + str(batch_index))
        images = np.array(batch[b'data'])
        # Stored layout is (N, C, H, W); transpose to NHWC for TensorFlow.
        images = images.reshape(10000, 3, 32, 32).transpose((0, 2, 3, 1))
        labels = make_one_hot(np.array(batch[b'labels']), 10)
        # Removed unused 'dickeys' local and debug shape print.
        return images, labels

    def next_train_data(self):
        # Stub: sequential batch iteration is not implemented yet.
        return None
Train模块搭建
x、y 的 placeholder 构建
# Input placeholders: x holds a batch of images, y the one-hot labels.
# NOTE(review): the x line is truncated in this excerpt (channel dimension
# and closing bracket missing) — restore from the full source file.
x = tf.placeholder("float32", [self.batch_size, self.input_height, self.input_width,
y = tf.placeholder("float32", [self.batch_size, 10],name='y')
learning_rate,global_step定义
# learning_rate is fed per step; global_step is advanced by minimize().
learning_rate = tf.placeholder("float", [])
# NOTE(review): the variable name string is misspelled ('gloabl_step');
# kept as-is because saved checkpoints reference it.
global_step = tf.Variable(0, trainable=False,name='gloabl_step')
调用resnet进行正向传播
# Forward pass: build ResNet-34 (v2) in training mode (enables BN updates).
res = resnet.resnet(is_training=True)
net = res.resnet_v2_34(x)
loss和optimizer构建,因为使用了BN,所以需要加入control_dependencies
# Hand-rolled cross-entropy on the softmax output; 0.0001 avoids log(0).
cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001)+0.0001)
# train op
opt = tf.train.AdamOptimizer(learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#for batch normalization: run moving-average updates with every train step
with tf.control_dependencies(update_ops):
    train_op = opt.minimize(cross_entropy, global_step=global_step)
正确率acc构建
# Accuracy: fraction of samples whose argmax matches the one-hot label.
correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
记录每一步的loss和acc,保存模型
# Scalar summaries for TensorBoard, with separate train/val writers.
tf.summary.scalar('loss', cross_entropy)
tf.summary.scalar('acc', accuracy)
merged = tf.summary.merge_all()
summary_writer_train = tf.summary.FileWriter("logs/" + '/train', sess.graph)
summary_writer_val = tf.summary.FileWriter("logs/" + '/val')  # no graph needed for the val writer
saver = tf.train.Saver()
训练过程,每训练一些epoch,就保存模型,进行val集的验证
#epoch iteration
for epoch_i in range(self.epoch_size):
    #batch iteration (100 mini-batches per epoch)
    for batch_i in range(100):
        #get data and label
        x_, y_ = mnist.train.next_batch(self.batch_size)
        #reshape into NHWC images / one-hot labels
        xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
        yy=np.reshape(y_,[self.batch_size,self.class_num])
        #train one step; learning rate fed as a constant here
        _, loss_value, step, acc,rs = sess.run([train_op, cross_entropy, global_step, accuracy,merged],
                                               feed_dict={x: xx, y: yy, learning_rate: 0.00001})
        print("After %d train epoch,loss on training batch is %g.and accuracy is %g." % (epoch_i, loss_value, acc))
    #train summary write (merged summary from the last batch of the epoch)
    summary_writer_train.add_summary(rs, epoch_i)
    #run val data every epoch
    if epoch_i % (1)==0:
        self.val(mnist,epoch_i,sess,summary_writer_val,merged)
    #save model (condition is true at epoch 0 and the final epoch)
    if epoch_i % (self.epoch_size-1) == 0:
        print("---After %d train epoch,loss on training batch is %g.and accuracy is %g." % (
            epoch_i, loss_value, acc))
        saver.save(sess, os.path.join(self.MODEL_SAVE_PATH, self.MODEL_NAME), global_step=global_step)
val集、test集运行代码如下:
def val(self,mnist,epoch_i,sess,summary_writer_val,merged):
    '''Run one validation pass over the test split and log a summary.
    :param mnist: mnist data set object (train/test splits)
    :param epoch_i: index of the current epoch (used as summary step)
    :param sess: same sess from train
    :param summary_writer_val: FileWriter for validation summaries
    :param merged: same merged summary op from train
    :return: None (prints the averaged accuracy)
    '''
    # NOTE(review): this rebuilds the loss/accuracy ops on every call,
    # which keeps adding nodes to the default graph across epochs —
    # consider creating them once outside this function.
    gragh = tf.get_default_graph()
    # get placeholders by their graph names rather than keeping references
    x = gragh.get_tensor_by_name("x:0")
    y = gragh.get_tensor_by_name("y:0")
    # get predict: softmax output of the network
    net = gragh.get_tensor_by_name("softmax/Softmax:0")
    # same loss formulation as training; 0.0001 avoids log(0)
    cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)
    correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
    acc_sum = 0
    # 156 batches — presumably batch_size=64 so 156*64 ≈ 10000 test
    # samples; confirm against self.batch_size.
    for batch_i in range(156):
        x_, y_ = mnist.test.next_batch(self.batch_size)
        xx = np.reshape(x_, [self.batch_size, self.input_height, self.input_width, self.input_channel])
        yy = np.reshape(y_, [self.batch_size, self.class_num])
        loss_value, acc,res = sess.run([cross_entropy, accuracy,merged],
                                       feed_dict={x: xx, y: yy})
        acc_sum+=acc
    print("test total acc:" + str(acc_sum / 156))
    # record the merged summary of the last validation batch for this epoch
    summary_writer_val.add_summary(res,epoch_i)
def test(self,mnist):
    '''Restore the trained model from disk and evaluate on the test split.
    :param mnist: mnist data set object
    :return: None (prints per-batch loss/accuracy and the total accuracy)
    '''
    # get model from meta: recreate the graph structure from the saved meta file
    saver = tf.train.import_meta_graph('./model/mnist_resnet_model.ckpt-5000.meta')
    gragh = tf.get_default_graph()  # current graph, used to look up restored tensors
    tensor_name_list = [tensor.name for tensor in gragh.as_graph_def().node]  # names of every node in the graph
    print(tensor_name_list)
    # get placeholders by name
    x = gragh.get_tensor_by_name("x:0")
    y = gragh.get_tensor_by_name("y:0")
    # get prediction: softmax output of the network
    net = gragh.get_tensor_by_name("softmax/Softmax:0")
    #loss — same formulation as training; 0.0001 avoids log(0)
    cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)
    #tf.summary.scalar('test_loss', cross_entropy)
    correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
    #acc
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
    #tf.summary.scalar('test_acc', accuracy)
    with tf.Session() as sess:
        # load the latest checkpoint weights into the restored graph
        saver.restore(sess, tf.train.latest_checkpoint('./model/'))
        acc_sum=0
        # 156 batches — presumably batch_size=64 so ~10000 test samples; confirm
        for batch_i in range(156):
            x_, y_ = mnist.test.next_batch(self.batch_size)
            xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
            yy=np.reshape(y_,[self.batch_size,self.class_num])
            loss_value, acc = sess.run([ cross_entropy, accuracy],
                                       feed_dict={x: xx, y: yy})
            print("After %d train batch,loss on training batch is %g.and accuracy is %g." % (batch_i, loss_value, acc))
            acc_sum=acc+acc_sum
        print("total acc:"+str(acc_sum/156))
cifar-10的train、test、val类似,只不过采用了tf中自带的loss
#loss — earlier hand-rolled variants kept below for reference
#cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)
#cross_entropy=tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net,labels=y))
# NOTE(review): softmax_cross_entropy_with_logits expects unscaled logits,
# but the network above ends in a softmax — confirm net is pre-softmax here.
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=net,labels=y))
opt = tf.train.AdamOptimizer(learning_rate)
最终在mnist训练了20个epoch得到结果如下
(tensorboard使用方式:在cmd中tensorboard --logdir=D:\pyproject\cifar\resnet\logs)
完整代码在github上:https://github.com/panxiaobai/ResNet_MNIST_TF