tensorflow实现CIFAR-10图片分类

本篇文章主要是利用tensorflow来构建简单的神经网络,主要写了怎样对数据集进行预处理并从数据集中读取数据,利用CIFAR-10数据集来实现图片的分类。数据集主要包括10类不同的图片,一共有60000张图片,50000张图片作为训练集,10000张图片作为测试集,每张图片的大小为32×32×3(彩色图片)。数据集格式如下所示:

首先http://www.cs.toronto.edu/~kriz/cifar.html 下载cifar数据集,选择python版本

打开jupyter notebook

在相同目录下创建文件, 首先查看数据集的文件目录

import pickle
import numpy
import os
CIFAR_DIR="./cifar-10-batches-py"
print(os.listdir(CIFAR_DIR))

['batches.meta', 'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', 'readme.html', 'test_batch']
#打开第一个训练文件,输出数据类型
with open(os.path.join(CIFAR_DIR,"data_batch_1"),'rb')as f:
    data=pickle.load(f,encoding='bytes')
    print(type(data))

<class 'dict'>

#输出数据字典里所有的键
print(data.keys())

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

#定义读取数据的函数
def load_data(filename):
    """read data from data file"""
    with open(filename,'rb') as f:
        data=pickle.load(f,encoding='bytes')
        return data[b'data'],data[b'labels']
#定义一个类,在这个类中对数据进行预处理,如何从数据集中提取数据
class CifarData:
    def __init__(self,filenames,need_shuffle):
        all_data=[]
        all_labels=[]
        for filename in filenames:
            data,labels=load_data(filename)
            all_data.append(data)
            all_labels.append(labels)
            
            #绑定两个数据
#             for item,label in zip(data,labels):
#                 if label in [0,1]:
#                     all_data.append(item)
#                     all_labels.append(label)
            #纵向上将数据合并在一起
        self._data=np.vstack(all_data)
        #归一化数据
        self._data=self._data/127.5-1
        self._labels=np.hstack(all_labels)
        print (self._data.shape)
        print (self._labels.shape)
        self._num_examples=self._data.shape[0]
        self._need_shuffle=need_shuffle
        self._indicator=0
        if self._need_shuffle:
            self._shuffle_data()
            
    def _shuffle_data(self):
        p=np.random.permutation(self._num_examples)
        self._data=self._data[p]
        self._labels=self._labels[p]
        
    def next_batch(self,batch_size):
        end_indicator=self._indicator+batch_size
        if end_indicator>self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator=0
                end_indicator=batch_size
            else:
                raise Exception("error")
        if end_indicator>self._num_examples:
            raise Exception("error")
        batch_data=self._data[self._indicator:end_indicator]
        batch_labels=self._labels[self._indicator:end_indicator]
        self._indicator=end_indicator
        return batch_data,batch_labels

#从数据集中读取训练集和测试集的数据
train_filenames=[os.path.join(CIFAR_DIR,'data_batch_%d' % i) for i in range(1,6)]
test_filenames=[os.path.join(CIFAR_DIR,'test_batch')]
train_data=CifarData(train_filenames,True)
test_data=CifarData(test_filenames,False)

#定义模型,计算模型的损失函数、优化器和准确率

tf.reset_default_graph() #清除默认图形堆栈并重置全局默认图形
x=tf.placeholder(tf.float32,[None,3072])
y=tf.placeholder(tf.int64,[None])

#十个类别
w=tf.get_variable('w',[x.get_shape()[-1],10],initializer=tf.random_normal_initializer(0,1))

b=tf.get_variable('b',[10],initializer=tf.constant_initializer(0.0))
y_=tf.matmul(x,w)+b
"""
p_y=tf.nn.softmax(y_)
y_one_hot=tf.one_hot(y,10,dtype=tf.float32)
loss=tf.reduce_mean(tf.square(y_one_hot-p_y))
"""
#logits可以自动去计算softmax
#labels去做onehot
#ylogy_
loss=tf.losses.sparse_softmax_cross_entropy(labels=y,logits=y_)


"""
p_y_1=tf.nn.sigmoid(y_) 
y_reshape=tf.reshape(y,(-1,1))
y_reshape_float=tf.cast(y_reshape,tf.float32)
loss=tf.reduce_mean(tf.square(y_reshape_float-p_y_1))
"""
predict=tf.argmax(y_,1)
correct_prediction=tf.equal(predict,y)
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float64))

with tf.name_scope('train_op'):
    train_op=tf.train.AdamOptimizer(1e-3).minimize(loss)

#进行训练

init=tf.global_variables_initializer()
batch_size=20
train_steps=10000
test_steps=100

with tf.Session() as sess:
    sess.run(init)
    for i in range(train_steps):
        batch_data,batch_labels=train_data.next_batch(batch_size)
        loss_val,accu_val,_=sess.run([loss,accuracy,train_op],
                                     feed_dict={x: batch_data,y: batch_labels})
        if (i+1)%500==0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % (i+1,loss_val,accu_val))
        if (i+1)%5000==0:
            test_data=CifarData(test_filenames,False)
            all_test_acc_val=[]
            for j in range(test_steps):
                test_batch_data,test_batch_labels=test_data.next_batch(batch_size)
                test_acc_val=sess.run([accuracy],feed_dict={x: test_batch_data,y: test_batch_labels})
                all_test_acc_val.append(test_acc_val)
            test_acc=np.mean(all_test_acc_val)
            print('[Test] Step: %d,acc: %4.5f' % (i+1,test_acc))

[Train] Step: 500, loss: 11.14932, acc: 0.20000
[Train] Step: 1000, loss: 14.60883, acc: 0.15000
[Train] Step: 1500, loss: 13.13694, acc: 0.05000
[Train] Step: 2000, loss: 7.39072, acc: 0.30000
[Train] Step: 2500, loss: 11.82805, acc: 0.20000
[Train] Step: 3000, loss: 7.93388, acc: 0.35000
[Train] Step: 3500, loss: 11.55873, acc: 0.10000
[Train] Step: 4000, loss: 11.87027, acc: 0.05000
[Train] Step: 4500, loss: 6.54007, acc: 0.30000
[Train] Step: 5000, loss: 4.06046, acc: 0.25000
(10000, 3072)
(10000,)
[Test] Step: 5000,acc: 0.26400
[Train] Step: 5500, loss: 4.96068, acc: 0.25000
[Train] Step: 6000, loss: 6.35689, acc: 0.15000
[Train] Step: 6500, loss: 7.28017, acc: 0.25000
[Train] Step: 7000, loss: 6.17465, acc: 0.25000
[Train] Step: 7500, loss: 5.71366, acc: 0.30000
[Train] Step: 8000, loss: 4.97062, acc: 0.30000
[Train] Step: 8500, loss: 7.81711, acc: 0.15000
[Train] Step: 9000, loss: 5.15836, acc: 0.25000
[Train] Step: 9500, loss: 5.21965, acc: 0.35000
[Train] Step: 10000, loss: 10.47672, acc: 0.10000
(10000, 3072)
(10000,)
[Test] Step: 10000,acc: 0.28000

 

 

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值