弱者用泪水安慰自己,强者用汗水磨练自己。
上一篇文章里面讲了使用TensorFlow做手写数字图像识别,这篇文章算是它的进阶篇吧,在本篇文章中将会讲解如何使用TensorFlow识别多种类图片。本次使用的数据集是CIFAR-10,这是一个比较经典的数据集,可以去百度一下它的官网,它包含60000张32X32的彩色图像,其中训练集50000张,测试集10000张。里面一共是10类的图片,分别是airplane、automobile、bird、cat、deer、dog、frog、horse、ship和truck。
第一步我们需要下载TensorFlow Models库,你可以去github上面下载也可以使用git指令下载
git clone https://github.com/tensorflow/models.git
导入库,定义batch_size、训练轮数max_steps,以及下载CIFAR-10的路径
from tensorflow.models.tutorials.image.cifar10 import cifar10, cifar10_input
import tensorflow as tf
import numpy as np
import time
max_steps=3000
batch_size=128
data_dir='/cifar10_data'
定义初始化weight的函数,使用tf.truncated_normal截断的正太分布,给weight加一个L2的loss,L2正则化可以帮助我们筛选出最有效的特征。使用w1控制L2 loss的大小,使用tf.nn.l2_loss函数计算weight的L2 loss,再使用tf.multiply让L2 loss乘以w1,得到最后的weight loss,使用tf.add_to.collection把weight loss统一存到一个collection并命名为losses,以后计算神经网络总体的loss会用。
def variable_with_weight_loss(shape,stddev,w1):
var = tf.Variable(tf.truncated_normal(shape,stddev=stddev))
if w1 is not None:
weight_loss=tf.multiply(tf.nn.l2_loss(var),w1,name='weight_loss')
tf.add_to_collection('losses',weight_loss)
return var
使用cifar10来下载数据集,再使用cirfar10_input中的distorted_inputs函数产生训练需要使用的数据,