图像分类之花卉图像分类(四)训练模型

本来我想用tensorbaord来观察LOSS曲线变化的,但是我代码改得不对,如果有小伙伴改出来了,如果可以的话可以告诉我,我懒得改了。下面代码也是注意改成自己的路径

# 导入文件
import os
import numpy as np
import tensorflow as tf
import input_data
import model
import os
import time
import warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings('ignore')

# 变量声明
N_CLASSES = 5  # 五种花类型
IMG_W = 64  # resize图像,太大的话训练时间久
IMG_H = 64
BATCH_SIZE = 25
CAPACITY = 250
MAX_STEP =5000
learning_rate = 0.0005# 一般小于0.0001


train_dir = 'D:/flower_photos/input_data2/train'  # 训练样本的读入路径
val_dir = 'D:/flower_photos/input_data2/val'  # 验证样本的读入路径
logs_train_dir = 'D:/save2/train'  # logs存储路径
logs_val_dir = 'D:/save2/val'

train, train_label= input_data.get_files(train_dir)
val, val_label = input_data.get_files(val_dir)
# 训练数据及标签
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# 测试数据及标签
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

# 存放DropOut参数的容器,训练时为0.45,测试时为0
dropout_placeholdr = tf.placeholder(tf.float32)
# 是否是训练状况
train = tf.placeholder(tf.float32)

logits = model.inference(x, BATCH_SIZE, N_CLASSES,dropout_placeholdr,train)
loss = model.losses(logits, y_)
acc = model.evaluation(logits, y_)
train_op = model.trainning(loss, learning_rate)

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    val_writer = tf.summary.FileWriter(logs_val_dir)
    # val_writer = tf.summary.FileWriter(logs_val_dir, sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            tra_images, tra_labels = sess.run([train_batch, train_label_batch])
            _, tra_loss, tra_acc = sess.run([train_op, loss, acc],
                                            feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:0.45,train:1})
            if step % 100 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:0.45,train:1})
                train_writer.add_summary(summary_str, step)

            if step % 100 == 0:
                val_images, val_labels = sess.run([val_batch, val_label_batch])
                val_loss, val_acc = sess.run([loss, acc],
                                             feed_dict={x: val_images, y_: val_labels,dropout_placeholdr:1.0,train:0})
                print('** val loss = %.2f, val accuracy = %.2f%%  **' % (val_loss, val_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels,dropout_placeholdr:1.0,train:0})
                val_writer.add_summary(summary_str, step)

            if step % 100 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
    coord.join(threads)

其中save文件夹中存储的就是训练好的模型,这个在后面测试的时候会用到。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值