[深度学习从入门到女装]tensorflow实战-ResNet(cifar-10,MNIST)

本文使用tensorflow实现resnet,使用cifar-10、mnist作为数据集进行测试

 

ResNet网络构建

本文实现的ResNet为ResNet v2版本的block,并实现ResNet-34

res block v2版本,与v1版本的区别主要在于BN、ReLU、conv的顺序不同,文章中提出使用BN->ReLU->conv的顺序效果最好

实现如下:

    def res_block_v2(self, input, out_channels, kernel_size=3, stride=1):
        """Pre-activation residual block: BN->ReLU->conv->BN->ReLU->conv.

        The first conv applies `stride`; the second always uses stride 1.
        When the block downsamples (stride > 1) or widens the channel count,
        the shortcut is projected with a 1x1 conv so both branches can be
        added element-wise.

        :param input: NHWC input tensor
        :param out_channels: number of output channels for both convs
        :param kernel_size: square kernel side for both convs
        :param stride: stride of the first conv (and of the 1x1 shortcut)
        :return: shortcut + residual branch, same shape as the conv output
        """
        input_channels = input.get_shape().as_list()[3]
        init = tf.truncated_normal_initializer(stddev=0.1)

        inner = tf.layers.batch_normalization(input, training=self.is_training,
                                              gamma_initializer=init)
        inner = tf.nn.relu(inner)
        inner = tf.layers.conv2d(inner, out_channels, [kernel_size, kernel_size],
                                 strides=[stride, stride], padding="SAME",
                                 use_bias=True, activation=None,
                                 kernel_initializer=init)

        inner = tf.layers.batch_normalization(inner, training=self.is_training,
                                              gamma_initializer=init)
        inner = tf.nn.relu(inner)
        inner = tf.layers.conv2d(inner, out_channels, [kernel_size, kernel_size],
                                 strides=[1, 1], padding="SAME",
                                 use_bias=True, activation=None,
                                 kernel_initializer=init)

        # Project the shortcut only when shape changes; identity otherwise.
        if stride > 1 or out_channels > input_channels:
            shortcut = tf.layers.conv2d(input, out_channels, [1, 1],
                                        strides=[stride, stride], padding="SAME",
                                        use_bias=True, activation=None,
                                        kernel_initializer=init)
        else:
            shortcut = input
        # NOTE(review): leftover debug print statements were removed.
        return inner + shortcut

 

基于v2的resnet-34实现如下:

    def resnet_v2_34(self, input):
        """Build a ResNet-34 classifier from pre-activation (v2) blocks.

        Stage layout follows the 34-layer configuration: 7x7/2 stem conv,
        3x3/2 max-pool, then 3/4/6/3 residual blocks at 64/128/256/512
        channels, global average pooling, a 10-way dense layer and softmax.

        :param input: NHWC image batch
        :return: softmax class probabilities, shape (batch, 10)
        """
        # NOTE(review): the original accumulated every stage output in an
        # unused `layers` list and printed shapes for debugging; both removed.
        inner = input
        with tf.variable_scope('conv1'):
            inner = tf.layers.conv2d(inner, 64, [7, 7], padding='SAME', strides=[2, 2],
                                     kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope('conv2'):
            inner = tf.layers.max_pooling2d(inner, [3, 3], [2, 2], padding='SAME')
            for i in range(3):
                inner = self.res_block_v2(inner, 64)

        with tf.variable_scope('conv3'):
            inner = self.res_block_v2(inner, 128, stride=2)
            for i in range(3):
                inner = self.res_block_v2(inner, 128)

        with tf.variable_scope('conv4'):
            inner = self.res_block_v2(inner, 256, stride=2)
            for i in range(5):
                inner = self.res_block_v2(inner, 256)

        with tf.variable_scope('conv5'):
            inner = self.res_block_v2(inner, 512, stride=2)
            for i in range(2):
                inner = self.res_block_v2(inner, 512)

        with tf.variable_scope('global_average_pool'):
            # Mean over spatial dims -> (batch, 512)
            inner = tf.reduce_mean(inner, [1, 2])

        with tf.variable_scope('fc'):
            inner = tf.layers.dense(inner, 10,
                                    kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope('softmax'):
            inner = tf.nn.softmax(inner)

        return inner

 

数据集读取

mnist数据集的读取在tf中有官方API,这里不再赘述,直接上代码

from tensorflow.examples.tutorials.mnist import input_data


# Download (if needed) and load MNIST with one-hot labels.
mnist=input_data.read_data_sets("./MNIST_data",one_hot=True)



# Get one training batch: x_ are flat 784-vectors, y_ one-hot labels.
# NOTE(review): this is an excerpt from a method — `self.batch_size` etc.
# are attributes of the surrounding trainer class; it will not run at
# module level as written.
x_, y_ = mnist.train.next_batch(self.batch_size)
# Reshape to NHWC images / (batch, class_num) labels.
xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
yy=np.reshape(y_,[self.batch_size,self.class_num])


# Same pattern for a test batch.
x_, y_ = mnist.test.next_batch(self.batch_size)
# Reshape.
xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
yy=np.reshape(y_,[self.batch_size,self.class_num])

 

cifar-10数据集的目录如下:

batches.meta、readme.html

data_batch_1、data_batch_2、data_batch_3、data_batch_4、data_batch_5

test_batch

训练数据都在data_batch中,一共5个文件,每个文件10000个样本,一共50000个训练样本,test_batch中为测试集,也是10000个样本,每个样本像素为32*32(3个颜色通道),每个样本所占的字节为1(标签)+32*32*3(像素数据)个

读取cifar-10数据集的代码如下:

import os
import pickle

import numpy as np



# 读取单个的batch文件
def unpickle(file):
    with open('D:\pyproject\data\CIFAR\cifar-10-python\cifar-10-batches-py\\' + file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def one_hot(x, n):
    """Convert a 1-D vector of class indices to a one-hot matrix.

    :param x: 1-D sequence/array of integer class indices
    :param n: number of classes (columns of the result)
    :return: float array of shape (len(x), n)
    """
    indices = np.array(x)
    assert indices.ndim == 1
    identity = np.eye(n)
    return identity[indices]

def make_one_hot(x, n):
    """Vectorized one-hot encoding via broadcast comparison.

    Compares each label in `x` against arange(n), yielding a 0/1 matrix of
    shape (len(x), n).

    :param x: 1-D integer numpy array of class indices
    :param n: number of classes
    :return: int array of shape (len(x), n)
    """
    # Fix: astype(np.integer) converted an abstract scalar type to a dtype,
    # which is deprecated (and rejected by modern NumPy); use the builtin int.
    return (np.arange(n) == x[:, None]).astype(int)

class cifar_reader:
    """Minimal CIFAR-10 batch reader.

    NOTE(review): `data_dir` is stored but never used — `train_reader`
    delegates to the module-level `unpickle`, which resolves the path
    itself; wire `self.data_dir` through once `unpickle` accepts one.
    """

    def __init__(self, data_dir, image_height=32, image_width=32, image_depth=3, label_bytes=1):
        # Geometry of one CIFAR-10 sample (32x32 RGB, 1 label byte).
        self.data_dir = data_dir
        self.image_height = image_height
        self.image_width = image_width
        self.image_depth = image_depth
        self.label_bytes = label_bytes

    def train_reader(self, batch_index):
        """Load one training batch as (images, one-hot labels).

        :param batch_index: 1..5, selects data_batch_<index>
        :return: (NHWC uint8 array of shape (10000, 32, 32, 3),
                  int one-hot array of shape (10000, 10))
        """
        mydata = unpickle('data_batch_' + str(batch_index))

        # Raw rows are channel-major (N, 3, 32, 32); transpose to NHWC.
        X = np.array(mydata[b'data'])
        train_data = X.reshape(10000, 3, 32, 32).transpose((0, 2, 3, 1))

        train_label = make_one_hot(np.array(mydata[b'labels']), 10)
        return train_data, train_label

    def next_train_data(self):
        # Unimplemented stub in the original; kept for interface parity.
        return None


 

Train模块搭建

x、y placeholder 构建

# Input images (NHWC) and one-hot labels.
# NOTE(review): the x line was truncated in the post; reconstructed from the
# np.reshape calls elsewhere (batch, height, width, channel) and from val()'s
# graph lookup of "x:0", which requires name='x'.
x = tf.placeholder("float32", [self.batch_size, self.input_height, self.input_width, self.input_channel], name='x')
y = tf.placeholder("float32", [self.batch_size, 10], name='y')

learning_rate,global_step定义

# Learning rate is fed per-step via feed_dict; scalar placeholder.
learning_rate = tf.placeholder("float", [])
# NOTE(review): 'gloabl_step' is a typo of 'global_step'; kept byte-identical
# because saved checkpoints reference variables by this name.
global_step = tf.Variable(0, trainable=False,name='gloabl_step')

调用resnet进行正向传播

# Forward pass; is_training=True puts batch-norm in training mode.
res = resnet.resnet(is_training=True)
# `net` is already softmaxed (resnet_v2_34 ends with tf.nn.softmax).
net = res.resnet_v2_34(x)

loss和optimizer构建,因为使用了BN,所以需要加入control_dependencies

# Cross-entropy against the softmaxed `net`; 1e-4 inside the log guards
# against log(0).  Fix: the original also added 1e-4 *outside* the log,
# which inflated the reported loss by a constant per element — removed.
cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001))

# train op
opt = tf.train.AdamOptimizer(learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Batch-norm moving-average updates live in UPDATE_OPS; chain them so they
# run on every optimizer step.
with tf.control_dependencies(update_ops):
    train_op = opt.minimize(cross_entropy, global_step=global_step)

正确率acc构建

# Batch accuracy: fraction of rows where argmax(prediction) == argmax(label).
correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))

记录每一步的loss和acc,保存模型

# Scalar summaries for TensorBoard; `merged` bundles them for one sess.run.
tf.summary.scalar('loss', cross_entropy)
tf.summary.scalar('acc', accuracy)
merged = tf.summary.merge_all()
# NOTE(review): `sess` must already exist here in the full script.
summary_writer_train = tf.summary.FileWriter("logs/" + '/train', sess.graph)
summary_writer_val = tf.summary.FileWriter("logs/" + '/val')  # the val writer does not need the graph
saver = tf.train.Saver()

训练过程,每训练一些epoch,就保存模型,进行val集的验证

        # Epoch loop: train 100 batches, log summaries, validate, maybe save.
        for epoch_i in range(self.epoch_size):
            #batch iteration
            for batch_i in range(100):
                #get data and label
                x_, y_ = mnist.train.next_batch(self.batch_size)
                #reshape
                xx=np.reshape(x_,[self.batch_size,self.input_height,self.input_width,self.input_channel])
                yy=np.reshape(y_,[self.batch_size,self.class_num])
                #train
                _, loss_value, step, acc,rs = sess.run([train_op, cross_entropy, global_step, accuracy,merged],
                                                    feed_dict={x: xx, y: yy, learning_rate: 0.00001})


            # NOTE(review): loss_value/acc/rs here come from the *last* batch
            # of the epoch only, not an epoch-wide average.
            print("After %d train epoch,loss on training batch is %g.and accuracy is %g." % (epoch_i, loss_value, acc))
            #train summary write
            summary_writer_train.add_summary(rs, epoch_i)
            #run val data: validate after every epoch (modulo 1)
            if epoch_i % (1)==0:
                self.val(mnist,epoch_i,sess,summary_writer_val,merged)
            #save model
            # NOTE(review): this condition also fires at epoch 0 (0 % k == 0),
            # saving a barely-trained model; presumably only the final epoch
            # was intended — confirm.
            if epoch_i % (self.epoch_size-1) == 0:
                print("---After %d train epoch,loss on training batch is %g.and accuracy is %g." % (
                epoch_i, loss_value, acc))
                saver.save(sess, os.path.join(self.MODEL_SAVE_PATH, self.MODEL_NAME), global_step=global_step)

val集、test集运行代码如下:

    def val(self, mnist, epoch_i, sess, summary_writer_val, merged):
        '''Run the held-out set through the trained graph and log summaries.

        :param mnist: mnist data set object
        :param epoch_i: index of the current epoch (summary step)
        :param sess: the same session used by train
        :param summary_writer_val: summary writer for the val log directory
        :param merged: the same merged summary op used by train
        :return: None
        '''
        graph = tf.get_default_graph()

        # Look up the placeholders and prediction tensor by name.
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        net = graph.get_tensor_by_name("softmax/Softmax:0")

        # Fix: the original rebuilt the loss/accuracy ops on *every* call,
        # growing the graph each epoch.  Build once and cache on self.
        if not hasattr(self, '_val_metric_ops'):
            cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)
            correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
            self._val_metric_ops = (cross_entropy, accuracy)
        cross_entropy, accuracy = self._val_metric_ops

        acc_sum = 0
        # 156 batches; NOTE(review): this iterates mnist.test, so the "val"
        # metric is actually computed on the test split — confirm intent.
        for batch_i in range(156):
            x_, y_ = mnist.test.next_batch(self.batch_size)
            xx = np.reshape(x_, [self.batch_size, self.input_height, self.input_width, self.input_channel])
            yy = np.reshape(y_, [self.batch_size, self.class_num])

            loss_value, acc, res = sess.run([cross_entropy, accuracy, merged],
                                            feed_dict={x: xx, y: yy})
            acc_sum += acc

        print("test total acc:" + str(acc_sum / 156))
        # Only the last batch's summary is written, tagged with the epoch.
        summary_writer_val.add_summary(res, epoch_i)

    def test(self, mnist):
        '''Restore the saved model and evaluate accuracy on the test set.

        :param mnist: mnist data set object
        :return: None (prints per-batch and total accuracy)
        '''
        # Rebuild the graph from the exported meta file, then look up the
        # tensors we need by name.
        # NOTE(review): removed the original's debug dump of every tensor
        # name in the graph.
        saver = tf.train.import_meta_graph('./model/mnist_resnet_model.ckpt-5000.meta')
        gragh = tf.get_default_graph()

        # get placeholders
        x = gragh.get_tensor_by_name("x:0")
        y = gragh.get_tensor_by_name("y:0")

        # get prediction (softmax output of the restored network)
        net = gragh.get_tensor_by_name("softmax/Softmax:0")

        # loss: cross-entropy with a small epsilon to avoid log(0)
        cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)

        correct_prediction = tf.equal(tf.argmax(net, 1), tf.argmax(y, 1))
        # acc
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))

        with tf.Session() as sess:
            # Load the latest checkpoint's weights into the restored graph.
            saver.restore(sess, tf.train.latest_checkpoint('./model/'))

            acc_sum = 0

            # 156 batches over the test split.
            for batch_i in range(156):
                x_, y_ = mnist.test.next_batch(self.batch_size)
                xx = np.reshape(x_, [self.batch_size, self.input_height, self.input_width, self.input_channel])
                yy = np.reshape(y_, [self.batch_size, self.class_num])

                loss_value, acc = sess.run([cross_entropy, accuracy],
                                           feed_dict={x: xx, y: yy})
                print("After %d train batch,loss on training batch is %g.and accuracy is %g." % (batch_i, loss_value, acc))
                acc_sum = acc + acc_sum

            print("total acc:" + str(acc_sum / 156))

cifar-10的train、test、val类似,只不过采用了tf中自带的loss

#loss
#cross_entropy = -tf.reduce_sum(y * tf.log(net + 0.0001) + 0.0001)
#cross_entropy=tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net,labels=y))
# NOTE(review): resnet_v2_34 ends with tf.nn.softmax, so `net` holds
# probabilities, not logits — softmax_cross_entropy_with_logits will apply
# softmax a second time here; the pre-softmax tensor should be fed instead.
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=net,labels=y))
opt = tf.train.AdamOptimizer(learning_rate)

 

最终在mnist训练了20个epoch得到结果如下

(tensorboard使用方式:在cmd中tensorboard --logdir=D:\pyproject\cifar\resnet\logs)

 

完整代码在github上:https://github.com/panxiaobai/ResNet_MNIST_TF

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值