猫狗识别


# coding: utf-8

# 定义功能函数

# VGG16:输入层——卷积——卷积——池化——卷积——卷积——池化——卷积——卷积——卷积——池化——卷积——卷积——卷积——池化——卷积——卷积——卷积——池化——全连接——全连接——全连接——softmax  

# In[1]:

import os
import tensorflow as tf
from time import time
import VGG16_model as model
import utils#定义了我们所用到的功能函数
from scipy.misc import imread,imresize
import numpy as np


# In[2]:

startTime=time()
batch_size=32
capacity=180#内存中存储的最大数据容量,根据自己的电脑配置而定
means=[123.68,116.779,103.939]#VGG训练时图像预处理所减均值(RGB三通道)
epoch=tf.Variable(0,name='epoch',trainable=False)#这个是不可训练的,相当于一个断点值,执行断点续训
sess=tf.Session()#声明会话
init=tf.global_variables_initializer()#调用变量
sess.run(init)#运行变量


# In[3]:

#设置检查点存储目录
ckpt_dir='./model/'
if not os.path.exists(ckpt_dir):#如果目录下不存在ckpt_dir
    os.makedirs(ckpt_dir)#创建ckpt_dir文件
saver=tf.train.Saver(max_to_keep=1)#生成saver,用于保存和提取变量
#如果有检查点文件,读取最新检查点文件,恢复各种变量值
ckpt=tf.train.latest_checkpoint(ckpt_dir)
#创建summary_writer,用于写图文件
summary_writer=tf.summary.FileWriter(ckpt_dir,sess.graph)
#如果有检查点文件,恢复检查点文件,恢复各种变量值
ckpt=tf.train.latest_checkpoint(ckpt_dir)
#saver.restore(sess,'./model/')#恢复最后保存的模型
if ckpt !=None:
    saver.restore(sess,ckpt)#加载所有的参数
    #从这里开始就可以直接使用模型进行预测,或者接着继续训练了
else:
    print('training from scratch')
#获取训练参数
start=sess.run(epoch)
print('traing starts from {} epoch'.format(start+1))


# In[ ]:

xs,ys=utils.get_file('data/train/')#获取图像列表和标签列表
image_batch,label_batch=utils.get_batch(xs,ys,224,224,batch_size,capacity)#通过读取列表来载入批量图片及标签
x=tf.placeholder(tf.float32,[None,224,224,3])
y=tf.placeholder(tf.int32,[None,2])
vgg=model.vgg16(x)#输出模型
fc8_finetuining=vgg.probs#即sofemax(fc8)微调(finetuining)sofemax(fc8)
loss_function=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc8_finetuining,labels=y))
optimizer=tf.train.GradientDescentOptimizer(0.001).minimize(loss_function)#GradientDescentOptimizer下降优化器
sess=tf.Session()#声明会话
init=tf.global_variables_initializer()#调用变量
sess.run(init)#运行变量
vgg.load_weights('vgg16_weights.npz',sess)#通过npz格式的文件获取VGG的相应权重参数,从而将权重注入即可实现复用
saver=tf.train.Saver()#生成saver,用于保存和提取变量
print('Model restoting......')
#saver.restore(sess,'./model/')#恢复最后保存的模型
#saver.restore(sess,'.model/epoch_00800.ckpt')恢复指定检查点的模型
#print('traing starts from {} epoch'.format(start+1))


coord=tf.train.Coordinator()#使用协调器Coordinator来管理线程
threads=tf.train.start_queue_runners(coord=coord,sess=sess)
epoch_start_time=time()
for i in range(start,1000):
    images,labels=sess.run([image_batch,label_batch])
    labels=utils.onehot(labels)#用one-hot对标签进行编码
    sess.run(optimizer,feed_dict={x:images,y:labels})
    loss=sess.run(loss_function,feed_dict={x:images,y:labels})
    print('现在的损失为:%f'%loss)
    epoch_end_time=time()
    print('当前训练花费的时间:',(epoch_end_time-epoch_start_time))
    epoch_start_time=epoch_end_time
      #保存检查点
    saver.save(sess,os.path.join('model/','epoch{:06d}.ckpt'.format(i)), global_step=i+1)
    sess.run(epoch.assign(i+1))
    print('===============Epoch %d is finished==============='%i)

#模型保存
#saver.save(sess,'./model/')
print('Optimization Finished!')
duration=time()-startTime
print('训练完成花费的时间:','{:.2f}'.format(duration))

coord.request_stop()#通知其它线程关闭
coord.join(threads)#join操作等待其他线程结束,其他所有线程关闭之后,这一函数才能返回


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值