TensorFlow 学习之 CIFAR-10 模型定义(补)

53 篇文章 1 订阅
33 篇文章 1 订阅
# -*- coding: utf-8 -*-
import os
import re
import sys
import tarfile
import urllib
import urllib.request

import tensorflow as tf

import new_cifar10_input

FLAGS=tf.app.flags.FLAGS  # Holds the values parsed from the command-line flags defined below

# Model parameters, exposed as command-line flags
tf.app.flags.DEFINE_integer('batch_size',128,"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir','/tmp/cifar10_data',"""Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16',False,"""Train the model using fp16.""")

# Global constants describing the data set, re-exported from new_cifar10_input
# NOTE(review): `IMAGE_SISE` looks like a misspelling of `IMAGE_SIZE` -- confirm the attribute name in new_cifar10_input
IMAGE_SIZE =new_cifar10_input.IMAGE_SISE
NUM_CLASSES =new_cifar10_input.NUM_CLASSES
NUM_EXAMOLES_PER_EPOCH_FOR_TRAIN =new_cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMOLES_PER_EPOCH_FOR_EVAL = new_cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

# Constants describing the training process
MOVING_AVERAGE_DEVAY=0.999  # Decay rate used for the moving average of the trainable variables
NUM_EPOCHS_PER_DECAY=350.0   # Width of one learning-rate step: the rate is decayed once every 350 epochs
LEARNING_RATE_DECAY_FACTOR=0.1 # Factor the learning rate is multiplied by at each decay
INITIAL_LEARNING_RATE=0.1      # Learning rate at the start of training

# Prefix stripped from op names when building summary tags (see _activation_summary)
TOWER_NAME='tower'

# Location of the CIFAR-10 binary archive fetched by maybe_download_and_extract
DATA_URL='http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
    """Attach TensorBoard summaries to an activation tensor.

    Records a histogram of the values of ``x`` and a scalar holding the
    fraction of its entries that are zero (a measure of sparsity).

    Args:
        x: Tensor to summarise; its op name is used to build the summary tags.
    """
    # Drop any "<TOWER_NAME>_<N>/" prefix from the op name before using it
    # as a summary tag.
    tensor_name = re.sub(r'%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
    tf.summary.histogram(tensor_name + '/activations', x)
    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
    """Create (or look up) a variable pinned to the CPU.

    Args:
        name: Name of the variable within the current variable scope.
        shape: Shape of the variable.
        initializer: Initializer passed to ``tf.get_variable``.

    Returns:
        The variable, stored as float16 when ``FLAGS.use_fp16`` is set and
        float32 otherwise.
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    # tf.device is the context manager that selects the device for the ops
    # created inside it; the variable is placed on CPU 0.
    with tf.device('/cpu:0'):
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
    return var

def _variable_with_weight_decay(name, shape, stddev, wd):
    """Create a truncated-normal-initialised variable, optionally L2-penalised.

    Args:
        name: Name of the variable.
        shape: Shape of the variable.
        stddev: Standard deviation of the truncated normal initializer.
        wd: Weight-decay coefficient. When not ``None``, ``wd * l2_loss(var)``
            is added to the ``'losses'`` collection.

    Returns:
        The created variable.
    """
    if FLAGS.use_fp16:
        float_type = tf.float16
    else:
        float_type = tf.float32

    initializer = tf.truncated_normal_initializer(stddev=stddev, dtype=float_type)
    variable = _variable_on_cpu(name, shape, initializer)

    if wd is None:
        return variable

    # Register the L2 penalty so that loss() picks it up via the collection.
    penalty = tf.multiply(tf.nn.l2_loss(variable), wd, name='weight_loss')
    tf.add_to_collection('losses', penalty)
    return variable

def distorted_inputs():
    """Build the distorted training input batch.

    Delegates to ``new_cifar10_input.distorted_inputs`` using the
    ``cifar-10-batches-bin`` directory under ``FLAGS.data_dir`` and
    ``FLAGS.batch_size``.

    Returns:
        An ``(images, labels)`` pair; both are cast to float16 when
        ``FLAGS.use_fp16`` is set.

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    batches_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    batch = new_cifar10_input.distorted_inputs(
        data_dir=batches_dir,
        batch_size=FLAGS.batch_size,
    )
    images, labels = batch

    if not FLAGS.use_fp16:
        return images, labels

    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
    return images, labels

def inputs(eval_data):
    """Build the undistorted input batch.

    Delegates to ``new_cifar10_input.inputs`` using the
    ``cifar-10-batches-bin`` directory under ``FLAGS.data_dir`` and
    ``FLAGS.batch_size``.

    Args:
        eval_data: Forwarded unchanged to ``new_cifar10_input.inputs``.

    Returns:
        An ``(images, labels)`` pair; both are cast to float16 when
        ``FLAGS.use_fp16`` is set.

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    # The batch size comes from the command-line flag, exactly as in
    # distorted_inputs(); there is no module-level `batch_size` name.
    images, labels = new_cifar10_input.inputs(eval_data=eval_data,
                                              data_dir=data_dir,
                                              batch_size=FLAGS.batch_size)

    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels

def inference(images):
    """Build the CIFAR-10 network and return its unnormalised logits.

    Layers, in order: conv1 (5x5, 3 -> 64 channels) + ReLU, 3x3/stride-2 max
    pool, LRN; conv2 (5x5, 64 -> 64) + ReLU, LRN, 3x3/stride-2 max pool;
    fully connected 384 and 192 units with ReLU; a final linear layer with
    ``NUM_CLASSES`` outputs.

    Args:
        images: Image batch fed to the first convolution. The flattening step
            reshapes to ``FLAGS.batch_size`` rows, so the batch dimension is
            presumably ``FLAGS.batch_size`` -- confirm against the caller.

    Returns:
        The output of the final linear layer (no softmax applied).
    """
    # NOTE(review): the variable name 'biased' (fc2) and the scope name
    # 'sotfmax_linear' look like typos, but they are part of the graph and
    # checkpoint variable names, so they are reproduced unchanged here.
    unit_strides = [1, 1, 1, 1]
    pool_window = [1, 3, 3, 1]
    pool_strides = [1, 2, 2, 1]

    # First convolution, followed by pooling and normalisation.
    with tf.variable_scope('conv1') as scope:
        conv1_kernel = _variable_with_weight_decay(
            'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)
        conv1_out = tf.nn.conv2d(images, conv1_kernel, unit_strides, padding='SAME')
        conv1_bias = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
        conv1_pre = tf.nn.bias_add(conv1_out, conv1_bias)
        conv1 = tf.nn.relu(conv1_pre, name=scope.name)
        _activation_summary(conv1)

    pool1 = tf.nn.max_pool(
        conv1, ksize=pool_window, strides=pool_strides, padding='SAME', name='pool1')
    norm1 = tf.nn.lrn(
        pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

    # Second convolution; here normalisation comes before pooling.
    with tf.variable_scope('conv2') as scope:
        conv2_kernel = _variable_with_weight_decay(
            'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
        conv2_out = tf.nn.conv2d(norm1, conv2_kernel, unit_strides, padding='SAME')
        conv2_bias = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
        conv2_pre = tf.nn.bias_add(conv2_out, conv2_bias)
        conv2 = tf.nn.relu(conv2_pre, name=scope.name)
        _activation_summary(conv2)

    norm2 = tf.nn.lrn(
        conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
    pool2 = tf.nn.max_pool(
        norm2, ksize=pool_window, strides=pool_strides, padding='SAME', name='pool2')

    # Fully connected layers.
    with tf.variable_scope('fc1') as scope:
        # Flatten every example into one row so a single matmul can be used.
        flat = tf.reshape(pool2, [FLAGS.batch_size, -1])
        flat_dim = flat.get_shape()[1].value
        fc1_weights = _variable_with_weight_decay(
            'weights', shape=[flat_dim, 384], stddev=0.04, wd=0.004)
        fc1_bias = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
        fc1 = tf.nn.relu(tf.matmul(flat, fc1_weights) + fc1_bias, name=scope.name)
        _activation_summary(fc1)

    with tf.variable_scope('fc2') as scope:
        fc2_weights = _variable_with_weight_decay(
            'weights', shape=[384, 192], stddev=0.04, wd=0.004)
        fc2_bias = _variable_on_cpu('biased', [192], tf.constant_initializer(0.1))
        fc2 = tf.nn.relu(tf.matmul(fc1, fc2_weights) + fc2_bias, name=scope.name)
        _activation_summary(fc2)

    # Final linear transform producing the logits.
    with tf.variable_scope('sotfmax_linear') as scope:
        out_weights = _variable_with_weight_decay(
            'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
        out_bias = _variable_on_cpu(
            'biases', [NUM_CLASSES], tf.constant_initializer(0.0))
        softmax_linear = tf.add(
            tf.matmul(fc2, out_weights), out_bias, name=scope.name)
        _activation_summary(softmax_linear)

    return softmax_linear


def loss(logits, labels):
    """Compute the total loss: mean cross entropy plus all collected losses.

    Args:
        logits: Unnormalised outputs of the final linear layer.
        labels: Class indices in sparse (non one-hot) form; cast to int64
            before use.

    Returns:
        A tensor named ``'total_loss'``: the sum of everything in the
        ``'losses'`` collection, including the cross-entropy mean added here.
    """
    sparse_labels = tf.cast(labels, tf.int64)

    # Per-example loss computed from the sparse labels and the raw logits.
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sparse_labels,
        logits=logits,
        name='cross_entropy_per_example',
    )
    batch_mean = tf.reduce_mean(per_example, name='cross_entropy')
    tf.add_to_collection('losses', batch_mean)

    collected = tf.get_collection('losses')
    return tf.add_n(collected, name='total_loss')

def _add_loss_summaries(total_loss):
    """Add scalar summaries for every loss and its moving average.

    An exponential moving average with decay 0.9 is maintained for each entry
    of the ``'losses'`` collection and for ``total_loss``. For each of them,
    the raw value is logged under ``<op name>(raw)`` and the averaged value
    under ``<op name>``.

    Args:
        total_loss: The total loss tensor, tracked alongside the collection.

    Returns:
        The op that updates the moving averages.
    """
    averager = tf.train.ExponentialMovingAverage(0.9, name='avg')

    # Everything registered under 'losses', plus the total itself.
    tracked = tf.get_collection('losses') + [total_loss]
    update_op = averager.apply(tracked)

    for tensor in tracked:
        tag = tensor.op.name
        tf.summary.scalar(tag + '(raw)', tensor)
        tf.summary.scalar(tag, averager.average(tensor))

    return update_op


def train(total_loss, global_step):
    """Build the training op.

    Sets up a staircase exponentially decaying learning rate, a gradient
    descent optimizer, summaries for the learning rate, variables and
    gradients, and a moving average of all trainable variables.

    Args:
        total_loss: Total loss to minimise (see ``loss``).
        global_step: Variable counting the training steps; it is incremented
            by the gradient update and drives the learning-rate decay.

    Returns:
        A no-op named ``'train'`` that depends on both the gradient update
        and the variable moving-average update.
    """
    # Quantities that control the learning-rate schedule.
    num_batches_per_epoch = NUM_EXAMOLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate in steps as the global step grows.
    lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)
    tf.summary.scalar('learning_rate', lr)

    # Moving averages of all losses, plus their summaries.
    loss_averages_op = _add_loss_summaries(total_loss)

    # Compute gradients only after the loss averages have been updated.
    # control_dependencies expects a list of ops, so the single op is wrapped.
    with tf.control_dependencies([loss_averages_op]):
        opt = tf.train.GradientDescentOptimizer(lr)
        grads = opt.compute_gradients(total_loss)

    # Apply the gradients; this also increments global_step.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Histograms of the trainable variables.
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name, var)

    # Histograms of the gradients.
    for grad, var in grads:
        if grad is not None:
            tf.summary.histogram(var.op.name + '/gradients', grad)

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DEVAY,
                                                          global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies([apply_gradient_op, variable_averages_op]):
        train_op = tf.no_op(name='train')
    return train_op

def maybe_download_and_extract():
    """Download the CIFAR-10 archive and unpack it, skipping finished steps.

    The archive named by ``DATA_URL`` is stored in ``FLAGS.data_dir`` (created
    if missing). The download is skipped when the archive file already exists,
    and extraction is skipped when the ``cifar-10-batches-bin`` directory
    already exists.
    """
    dest_directory = FLAGS.data_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # Report hook for urlretrieve: rewrite one console line in place.
            # A non-positive total_size is shown as 0% rather than dividing
            # by zero.
            if total_size > 0:
                percent = float(count * block_size) / float(total_size) * 100.0
            else:
                percent = 0.0
            # One placeholder per argument: the file name and the percentage.
            sys.stdout.write('\r >>Downloading %s %.1f%%' % (filename, percent))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully download', filename, statinfo.st_size, 'bytes.')

    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        # The context manager closes the archive handle even if extraction fails.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(dest_directory)



代码中还有很多不理解的地方,需要进一步学习相关 API。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值