Stay fool & simple

落花无言 人淡如茶

cifar10_resNet

# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random

import tensorflow as tf

from dataset_cifar10 import load_data
from mnist import residual_network
from notmnist.dataset import randomize

def resnet(summaries_dir='', num_steps=20000, batch_size=50):
    """Train a residual network on CIFAR-10 and report test-set accuracy.

    Steps:
      1. Load CIFAR-10 with ``load_data`` and shuffle the training set with
         ``randomize``.
      2. Build ``residual_network.ResNet`` over 32x32x3 float inputs with 10
         classes. The loss is the mean softmax cross-entropy, and an Adam
         optimizer increments ``global_step`` on every update.
      3. If a checkpoint exists in ``<summaries_dir>/model``, restore it.
         Then train from the restored ``global_step`` up to ``num_steps``.
         TensorBoard summaries go to ``<summaries_dir>/train``. A checkpoint
         is saved every 500 steps and after the final step.
      4. Print the accuracy on 10 randomly placed 500-example slices of the
         test set, then their average.

    Args:
        summaries_dir: root directory for summaries and checkpoints. The
            default ``''`` resolves to the current working directory.
        num_steps: total number of training steps. Steps already completed
            by a restored checkpoint count toward this total.
        batch_size: number of training examples per step.

    Raises:
        ValueError: if the training set does not have more than
            ``batch_size`` examples. The batch-offset modulus below would
            otherwise be zero or negative.
    """
    (train_dataset, train_labels), (test_dataset, test_labels) = load_data()
    # Shuffle the training data once.
    # NOTE(review): assumes randomize applies one permutation to both arrays.
    train_dataset, train_labels = randomize(train_dataset, train_labels)
    num_train = train_labels.shape[0]
    if num_train <= batch_size:
        raise ValueError('need more than %d training examples, got %d'
                         % (batch_size, num_train))

    with tf.name_scope('input'):
        dataset = tf.placeholder(tf.float32, [None, 32, 32, 3], name='x_input')
        # Compared via argmax below, so one row of class scores per example
        # (e.g. one-hot) is expected.
        labels = tf.placeholder(tf.float32, [None, 10], name='y_input')
    # Created before the optimizer so that minimize() increments it on every
    # training step. It is saved in checkpoints and used to resume training.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    sess = tf.InteractiveSession()
    logits = residual_network.ResNet(dataset, 10)  # residual network

    with tf.name_scope('total'):
        # Mean softmax cross-entropy over the batch (the training loss).
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        tf.summary.scalar('cross_entropy', cross_entropy)

    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer().minimize(
            cross_entropy, global_step=global_step)
    with tf.name_scope('accuracy'):
        with tf.name_scope('corret_prediction'):
            # True where the predicted class (argmax of logits) matches the
            # true class (argmax of labels).
            correct_pre = tf.equal(tf.argmax(labels, 1), tf.argmax(logits, 1))
        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(correct_pre, tf.float32))
            tf.summary.scalar('accuracy', accuracy)

    merged = tf.summary.merge_all()
    saver = tf.train.Saver()
    # os.path.join keeps '' relative; string concatenation produced '/train'.
    train_writer = tf.summary.FileWriter(os.path.join(summaries_dir, 'train'),
                                         sess.graph)
    sess.run(tf.global_variables_initializer())
    print('initialized')

    ckpt_dir = os.path.join(summaries_dir, 'model')  # checkpoint directory
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    checkpoint_file = os.path.join(ckpt_dir, 'model.ckpt')
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    start = global_step.eval()
    print('start from', start)

    try:
        for step in range(start, num_steps):
            # The modulus keeps offset + batch_size <= num_train, so every
            # slice is a full batch.
            offset = (step * batch_size) % (num_train - batch_size)
            train_feed_dict = {
                dataset: train_dataset[offset:offset + batch_size],
                labels: train_labels[offset:offset + batch_size],
            }
            summary, _, acc = sess.run([merged, train_step, accuracy],
                                       feed_dict=train_feed_dict)
            train_writer.add_summary(summary, step)
            if step % 50 == 0:
                # Accuracy on the current training batch (not held-out data).
                print('step %d, training batch accuracy: %g' % (step, acc))
            # Save after the update, so the checkpoint (named with
            # global_step == step + 1) includes this step's training.
            if (step + 1) % 500 == 0 or step + 1 == num_steps:
                save_path = saver.save(sess, checkpoint_file,
                                       global_step=global_step)
                print('save model to:', save_path)

        # Test-set accuracy on 10 randomly placed slices.
        eval_batch = 500
        num_test = test_labels.shape[0]
        accs = []
        for i in range(10):
            offset = random.randint(0, max(num_test - eval_batch, 0))
            acc = accuracy.eval(feed_dict={
                dataset: test_dataset[offset:offset + eval_batch],
                labels: test_labels[offset:offset + eval_batch],
            })
            accs.append(acc)
            print('test accuracy %g for batch %d' % (acc, i))
        print('average test accuracy', sum(accs) * 1.0 / len(accs))
    finally:
        train_writer.close()
        sess.close()

if __name__ == '__main__':
    # Script entry point: run the full train / checkpoint / evaluate pipeline.
    resnet()



这几天一直连不上网，而且只有我的电脑连不上；网卡没有问题，可能是人品出了差错。面壁思过。
残差神经网络就这样？！！！

阅读更多
版权声明：本文为博主原创文章，未经博主允许不得转载。https://blog.csdn.net/hensonwells/article/details/77073691
个人分类: Coding DIARY tensorflow
想对作者说点什么? 我来说一句

cifar10_resnet.py例子

2018年01月28日 14KB 下载

没有更多推荐了,返回首页

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭