TensorFlow CNN in Practice on CIFAR-10 (Complete Code Included)

Dataset introduction
CIFAR-10 is a common benchmark dataset in deep learning. It consists of 60,000 32x32 RGB color images covering 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck, split into 50,000 training images and 10,000 test images. It is frequently used as a classification task for judging the quality of deep learning frameworks and models; well-known architectures such as AlexNet, NIN, and ResNet have all reported results on CIFAR-10. It has a sister dataset, CIFAR-100, which, as the name suggests, contains 100 classes and is correspondingly harder. For more background and download links, see the official website. Because CIFAR-10 is small, clearly labeled, easy to obtain, and quick to train on, with plenty of reference models to compare against, it is an ideal next step for deep learning beginners (this article assumes some familiarity with TensorFlow and its common parameters).
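
Each batch file in the python version of the dataset is just a pickle. A quick way to inspect the layout (a minimal sketch; the path assumes the official archive was unpacked next to the script):

import pickle

with open("./cifar-10-batches-py/data_batch_1", "rb") as fo:
    batch = pickle.load(fo, encoding="bytes")

print(batch.keys())          # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)  # (10000, 3072): per image, 1024 R then 1024 G then 1024 B bytes
print(batch[b'labels'][:5])  # integer class labels in 0~9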


Network structure
conv1->relu1->pool1->conv2->relu2->pool2->conv3->relu3->pool3->fc1->dropout1->fc2->out
All convolution kernels are 3x3; the pooling kernels are also 3x3, with a stride of 2.

A batch normalization (BN) layer is inserted after each convolution. batch_size is set to 100 and training runs for 500 iterations, after which the learning rate is reduced.
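
Before reading the code, it helps to trace the feature-map sizes, since they determine the 2*2*256 flattened width fed into fc1. Plain arithmetic matching the padding choices used below:

# VALID: out = floor((in - k) / s) + 1;  SAME with stride 1: out = in
def valid(in_size, k, s):
    return (in_size - k) // s + 1

size = valid(32, 3, 1)    # conv1, VALID 3x3/1: 32 -> 30
size = valid(size, 3, 2)  # pool1, VALID 3x3/2: 30 -> 14
size = valid(size, 3, 2)  # conv2 SAME keeps 14; pool2: 14 -> 6
size = valid(size, 3, 2)  # conv3 SAME keeps 6;  pool3: 6 -> 2
print(size * size * 256)  # 1024, i.e. 2*2*256, the input width of fc1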

main.py

import tensorflow as tf
import cifar_reader

batch_size = 100
train_iter = 50000
display_step = 10

input_x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
keep_prob = tf.placeholder(tf.float32)   # dropout keep probability
is_training = tf.placeholder(tf.bool)    # switches batch norm between train and inference mode

####conv1
W1 = tf.Variable(tf.truncated_normal([3, 3, 3, 64], dtype=tf.float32, stddev=5e-2))
conv_1 = tf.nn.conv2d(input_x, W1, strides=(1, 1, 1, 1), padding="VALID")  # 32x32x3 -> 30x30x64
print(conv_1)

bn1 = tf.layers.batch_normalization(conv_1, training=is_training)

relu_1 = tf.nn.relu(bn1)
print(relu_1)

pool_1 = tf.nn.max_pool(relu_1, strides=[1, 2, 2, 1], padding="VALID", ksize=[1, 3, 3, 1])  # 30x30x64 -> 14x14x64
print(pool_1)

####conv2
W2 = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 128], dtype=tf.float32, stddev=5e-2))
conv_2 = tf.nn.conv2d(pool_1, W2, strides=[1, 1, 1, 1], padding="SAME")  # 14x14x64 -> 14x14x128 (SAME keeps the spatial size)
print(conv_2)

bn2 = tf.layers.batch_normalization(conv_2, training=is_training)

relu_2 = tf.nn.relu(bn2)
print(relu_2)

pool_2 = tf.nn.max_pool(relu_2, strides=[1, 2, 2, 1], ksize=[1, 3, 3, 1], padding="VALID")  # 14x14x128 -> 6x6x128
print(pool_2)

####conv3
W3 = tf.Variable(tf.truncated_normal(shape=[3, 3, 128, 256], dtype=tf.float32, stddev=1e-1))
conv_3 = tf.nn.conv2d(pool_2, W3, strides=[1, 1, 1, 1], padding="SAME")  # 6x6x128 -> 6x6x256
print(conv_3)

bn3 = tf.layers.batch_normalization(conv_3, training=is_training)

relu_3 = tf.nn.relu(bn3)
print(relu_3)

pool_3 = tf.nn.max_pool(relu_3, strides=[1, 2, 2, 1], ksize=[1, 3, 3, 1], padding="VALID")  # 6x6x256 -> 2x2x256
print(pool_3)

#fc1
dense_tmp = tf.reshape(pool_3, shape=[-1, 2*2*256])  # flatten 2x2x256 -> 1024
print(dense_tmp)

fc1 = tf.Variable(tf.truncated_normal(shape=[2*2*256, 1024], stddev=0.04))

bn_fc1 = tf.layers.batch_normalization(tf.matmul(dense_tmp, fc1), training=is_training)

dense1 = tf.nn.relu(bn_fc1)
dropout1 = tf.nn.dropout(dense1, keep_prob)

#fc2
fc2 = tf.Variable(tf.truncated_normal(shape=[1024, 10], stddev=0.04))
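# no explicit bias on the output layer; softmax_cross_entropy_with_logits
# below consumes these raw logits directly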
out = tf.matmul(dropout1, fc2)
print(out)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=out, labels=y))
# tf.layers.batch_normalization keeps its moving mean/variance updates in the
# UPDATE_OPS collection; they must run with each training step, or inference
# mode (is_training=False) will normalize with stale statistics
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)

dr = cifar_reader.Cifar10DataReader(cifar_folder="./cifar-10-batches-py/")

# test the network
correct_pred = tf.equal(tf.argmax(out, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# initialize all variables
init = tf.global_variables_initializer()

saver = tf.train.Saver()

# start training
with tf.Session() as sess:
    sess.run(init)
    # saver.restore(sess, "model_tmp/cifar10_demo.ckpt")
    step = 1

    # keep training until the iteration cap is reached
    while step * batch_size < train_iter:
        batch_xs, batch_ys = dr.next_train_data(batch_size)
        # run one optimization step and fetch the minibatch loss and accuracy
        opt, acc, loss = sess.run([optimizer, accuracy, cost],
                                  feed_dict={input_x: batch_xs, y: batch_ys, keep_prob: 0.6, is_training: True})
        if step % display_step == 0:
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + ", Training Accuracy= " + "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # evaluate test accuracy
    num_examples = 10000
    d, l = dr.next_test_data(num_examples)
    print ("Testing Accuracy:", sess.run(accuracy, feed_dict={input_x: d, y: l, keep_prob: 1.0, is_traing: True}))
    saver.save(sess, "model_tmp/cifar10_demo.ckpt")
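
Feeding all 10,000 test images in a single sess.run can exhaust GPU memory on smaller cards. A batched evaluation along the following lines works as a drop-in replacement for the single large feed inside the same tf.Session() block (a sketch reusing the same reader; note that next_test_data reshuffles on every call, so this measures accuracy over random samples rather than a strict single pass over the test set):

    total_acc = 0.0
    eval_batches = 100
    for _ in range(eval_batches):
        d, l = dr.next_test_data(100)
        total_acc += sess.run(accuracy,
                              feed_dict={input_x: d, y: l, keep_prob: 1.0, is_training: False})
    print("Testing Accuracy:", total_acc / eval_batches)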

cifar_reader.py

import pickle
import os
import numpy as np

class Cifar10DataReader():
    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1
        self.read_next = True
        self.data_label_train = None
        self.data_label_test = None
        self.batch_index = 0

    def unpickle(self, f):
        # the python-version batches are pickled with bytes keys, hence encoding="bytes"
        with open(f, 'rb') as fo:
            d = pickle.load(fo, encoding="bytes")
        return d

    def next_train_data(self, batch_size=100):
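        # serve minibatches from the currently loaded data_batch file; once it
        # is exhausted, move on to the next of the five files, cycling back to 1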
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        rdata = None
        rlabel = None
        if self.read_next:
            f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
            print('read:', f)

            dic_train = self.unpickle(f)
            self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels']))  # label 0~9
            np.random.shuffle(self.data_label_train)

            self.read_next = False
            if self.data_index == 5:
                self.data_index = 1
            else:
                self.data_index += 1

        if self.batch_index < len(self.data_label_train) // batch_size:
            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            self.batch_index += 1
            rdata, rlabel = self._decode(datum, self.onehot)
        else:
            self.batch_index = 0
            self.read_next = True
            return self.next_train_data(batch_size=batch_size)

        return rdata, rlabel

    def _decode(self, datum, onehot):
        rdata = list()
        rlabel = list()
        if onehot:
            for d, l in datum:
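                # d holds 3072 bytes stored plane-by-plane (1024 R, 1024 G, 1024 B):
                # reshape to [3, 1024], transpose to [1024, 3], then fold into HWC [32, 32, 3]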
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
        else:
            for d, l in datum:
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read:', f)

            dic_test = self.unpickle(f)
            data = dic_test[b'data']
            labels = dic_test[b'labels']  # 0~9
            self.data_label_test = list(zip(data, labels))

        np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]

        return self._decode(datum, self.onehot)

Tuning notes

1. Non-convergence: removing one fully connected layer fixed it; a network that is too deep can sometimes fail to converge.
2. Adding the dropout and BN layers reduces overfitting and raised the accuracy from 60% to 75%.

After several further rounds of training with a reduced learning rate, the final accuracy exceeds 80%.
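
The code above hardcodes Adam's learning rate at 0.01. One way to run the reduced-rate stages is to feed the rate through a placeholder and restore the checkpoint between runs. A minimal sketch; the 0.01 -> 0.001 -> 0.0001 schedule is an assumption for illustration, not necessarily what produced the reported result:

lr = tf.placeholder(tf.float32)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(lr).minimize(cost)

# first run: feed lr=0.01; then uncomment the saver.restore line in main.py
# and continue training while feeding lr=0.001, and later 0.0001 (assumed schedule)
# sess.run(optimizer, feed_dict={input_x: batch_xs, y: batch_ys,
#                                keep_prob: 0.6, is_training: True, lr: 0.001})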

