Tensorflow训练CIFAR10源代码

本文介绍了使用Tensorflow训练CIFAR10数据集的过程,从数据集的特点到训练工程的搭建,详细阐述了从彩色图像处理到模型训练的步骤。通过实例代码展示了如何在Tensorflow中加载和预处理CIFAR10数据,以及如何构建和运行训练过程。
摘要由CSDN通过智能技术生成

最近看到tensorflow训练cifar10数据集,说实话相比于mnist数据集,cifar10有了一个质的飞跃,从单通道灰度图像转变到三通道彩色图像。

cifar10

下面来简单介绍下cifar10数据集,该数据集共有60000张彩色图像,这些图像是32*32*3,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。Tensorflow自带有cifar的例子,可以在线下载cifar数据集,也可以离线下载,然后读取数据,在这里主要讲解如何搭建训练工程。下面请看代码:

import cifar10,cifar10_input
import tensorflow as tf
import numpy as np
import time

max_steps = 3000
batch_size = 128
data_dir = 'C:\\Users\\new\\Desktop\\cifar-10-batches-bin'


def variable_with_weight_loss(shape, stddev, wl):
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if wl is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return var


def loss(logits, labels):
#      """Add L2Loss to all the trainable variables.
#      Add summary for "Loss" and "Loss/avg".
#      Args:
#        logits: Logits from inference().
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值