最近看到tensorflow训练cifar10数据集,说实话相比于mnist数据集,cifar10有了一个质的飞跃,从单通道灰度图像转变到三通道彩色图像。
cifar10
下面来简单介绍下cifar10数据集,该数据集共有60000张彩色图像,这些图像是32*32*3,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。Tensorflow自带有cifar的例子,可以在线下载cifar数据集,也可以离线下载,然后读取数据,在这里主要讲解如何搭建训练工程。下面请看代码:
import cifar10,cifar10_input
import tensorflow as tf
import numpy as np
import time
max_steps = 3000
batch_size = 128
data_dir = 'C:\\Users\\new\\Desktop\\cifar-10-batches-bin'
def variable_with_weight_loss(shape, stddev, wl):
var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
if wl is not None:
weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
tf.add_to_collection('losses', weight_loss)
return var
def loss(logits, labels):
# """Add L2Loss to all the trainable variables.
# Add summary for "Loss" and "Loss/avg".
# Args:
# logits: Logits from inference().