18.4:tensorflow分类模型mobilenetv2训练-图像转tfrecord格式并读取批量数据进行训练

    通过18.1-18.3已经可以训练一个完整的分类模型并对图像进行预测了。之前是从文件夹直接读取图像文件进行训练,本篇介绍如何把图像转为tfrecord格式,再从tfrecord文件读取批量数据进行训练。

一、图像转tfrecord文件

如下为jpg_to_tfrecord.py。

函数create_label_map()用于创建从图像类别到分类用的label的对应关系,保存在label.txt中;如果自己手动建立这个文件(参考18.1),则注释掉主程序中对create_label_map(data_path, label_dir)的调用(原脚本第124行),不使用该函数。函数read_label_map()用于读取label.txt。

jpg_to_record()函数用于把图像转为tfrecord。

get_batch_record()函数用于获取一个批量数据。

use_record()函数建立session并运行批量数据,把得到的图像数据进行保存。

#coding:utf-8
import os 
import glob
import cv2
import numpy as np
import tensorflow as tf


# jpg to tfrecord


def create_label_map(data_path, label_dir):
    """Write a ``folder:label`` mapping file for the classes of a dataset.

    Every entry of ``data_path`` that is not a regular file is treated as one
    class folder. Classes are numbered from 0 in sorted name order and written
    one ``name:index`` pair per line.

    Args:
        data_path: dataset root directory holding one sub-folder per class.
        label_dir: path of the label file to write (a file path, despite the
            parameter name).
    """
    # Build a filtered list instead of calling remove() while iterating:
    # removing during iteration skips the entry that follows each removed
    # file, so stray files could end up being counted as classes.
    # sorted() keeps the numbering reproducible; os.listdir order is arbitrary.
    classes = sorted(
        entry for entry in os.listdir(data_path)
        if not os.path.isfile(os.path.join(data_path, entry))
    )
    print("the class number is: %d, they are: %s" % (len(classes), classes))
    print("folder:label")
    # The context manager closes the file even if a write fails.
    with open(label_dir, "w") as f:
        for index, name in enumerate(classes):
            pair = "%s:%s" % (str(name), str(index))
            print(pair)
            f.write(pair + "\n")
    print("creat %s done..." % label_dir)


def read_label_map(label_dir):
    """Load a ``folder:label`` mapping file.

    Args:
        label_dir: path of the label file, one ``folder:label`` pair per line.

    Returns:
        A tuple ``(label_dict, label_dict_res)``: folder name -> label and
        label -> folder name. Labels are kept as strings, exactly as read.

    Raises:
        IndexError: if a non-blank line contains no ``:`` separator.
    """
    label_dict, label_dict_res = {}, {}
    with open(label_dir, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank or whitespace-only lines, which are common
                # in hand-written label files.
                continue
            parts = line.split(":")
            folder, cls = parts[0], parts[1]
            label_dict[folder] = cls
            label_dict_res[cls] = folder
    return label_dict, label_dict_res


def _int64_feature(value):
    """Wrap one integer in a tf.train.Feature backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Wrap one bytes value in a tf.train.Feature backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def jpg_to_record(data_path, save_path):
    """Convert the images of every class folder into a single TFRecord file.

    For each ``folder:label`` pair in the label file, every image matching
    ``*<img_extend>`` in ``data_path/<folder>`` is resized to
    ``resize x resize``, converted from BGR to RGB and stored as raw bytes
    under the key ``"image"``, with the integer class under ``"label"``.

    Reads the module-level settings ``label_dir``, ``img_extend`` and
    ``resize``.

    Args:
        data_path: dataset root directory holding one sub-folder per class.
        save_path: path of the TFRecord file to write.
    """
    label_dict, _ = read_label_map(label_dir)
    writer = tf.python_io.TFRecordWriter(save_path)
    try:
        for a, (name, index) in enumerate(label_dict.items(), 1):
            print("%3d/%d, image to tfrecord, processing %s:%s" % (a, len(label_dict), name, index))
            class_dir = os.path.join(data_path, name)
            img_list = glob.glob(os.path.join(class_dir, "*" + img_extend))
            if len(img_list) == 0:
                print('cannot find "%s" image in %s' % (img_extend, class_dir))
                # Skip only this class; aborting the whole loop here would
                # silently leave every remaining class out of the record.
                continue
            for img_path in img_list:
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread returns None for unreadable/corrupt files;
                    # passing that to cv2.resize would crash the conversion.
                    print("skip unreadable image: %s" % img_path)
                    continue
                img = cv2.resize(img, (resize, resize))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_raw = img.tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": _int64_feature(int(index)),
                    "image": _bytes_feature(img_raw),
                }))
                # Serialize the example to a string before writing it
                writer.write(example.SerializeToString())
    finally:
        # Close the record file even when an image fails mid-way.
        writer.close()
    print("jpg to tfrecord done...")


def get_batch_record(record_name, batch_size):
    """Build graph ops that yield shuffled (image, label) batches from a TFRecord.

    Reads the module-level ``resize`` setting to restore the image shape.

    Args:
        record_name: path of the TFRecord file to read.
        batch_size: number of examples per batch.

    Returns:
        ``(img_batch, label_batch)`` tensors: uint8 images reshaped to
        ``[resize, resize, 3]`` per example, and int32 labels.
    """
    # Layout of one serialized example: scalar int64 label + raw image bytes
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'image': tf.FixedLenFeature([], tf.string),
    }

    # Queue of input files (just one here) feeding the record reader
    filename_queue = tf.train.string_input_producer([record_name])
    _, serialized_example = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # Raw bytes -> uint8 pixels -> HxWx3 image
    pixels = tf.decode_raw(parsed['image'], tf.uint8)
    img = tf.reshape(pixels, [resize, resize, 3])
    label = tf.cast(parsed['label'], tf.int32)

    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=2000,
        min_after_dequeue=1500,
    )
    return img_batch, label_batch


def use_record(tfrecords_file, BATCH_SIZE):
    """Read three batches back from a TFRecord file and dump them as JPEGs.

    Prints each batch's labels with their folder names (looked up through the
    label file at the module-level ``label_dir``) and writes every image to
    the working directory as ``batch:<i>_img_<j>.jpg``.

    Args:
        tfrecords_file: path of the TFRecord file to read.
        BATCH_SIZE: number of examples fetched per batch.
    """
    _, label_dict_res = read_label_map(label_dir)
    image_batch, label_batch = get_batch_record(tfrecords_file, BATCH_SIZE)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(3):
                if coord.should_stop():
                    break
                image, label = sess.run([image_batch, label_batch])
                print(label)
                for j, (pixels, cls) in enumerate(zip(image, label)):
                    print('label: %d, truth label: %s' % (cls, label_dict_res[str(cls)]))
                    # Swap the R and B channels back before handing to OpenCV
                    img = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
                    out_name = "batch:" + str(i) + "_img_" + str(j) + ".jpg"
                    cv2.imwrite(out_name, img)
        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            coord.request_stop()
        coord.join(threads)


if __name__ == "__main__":
    # Output files, written relative to the working directory
    tfrecords_file = 'train.tfrecords'
    label_dir = "label.txt"
    # Settings read as module-level globals by the functions above:
    # img_extend is the glob suffix of the images, resize the square edge
    # length (pixels) images are resized to before being stored.
    img_extend = ".jpg"
    resize = 224
    data_path = "/home/ming/data/flower/train"
    # Comment out the next call to keep a hand-written label.txt
    create_label_map(data_path, label_dir)
    jpg_to_record(data_path, tfrecords_file)
    # Sanity check: read a few batches back and save them as images
    use_record(tfrecords_file=tfrecords_file, BATCH_SIZE = 4)

二、使用tfrecord文件训练,train_tfrecord.py

训练程序和之前相同,只是获取批量数据的函数改变了。

#coding:utf-8
import os
import numpy as np
import tensorflow as tf

import model


# Class-folder <-> label maps, loaded from the label.txt created by
# jpg_to_tfrecord.py (one "folder:label" pair per line). Labels stay strings.
label_dict, label_dict_res = {}, {}
with open("label.txt", 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines in a hand-edited label file
            continue
        parts = line.split(':')
        folder, label = parts[0], parts[1]
        label_dict[folder] = label
        label_dict_res[label] = folder
print(label_dict)

train_record_dir = "/home/ming/data/flower/train.tfrecords"
logs_train_dir = './model_save'
train_number = 1000  # total number of images in the training set
init_lr = 0.1
BATCH_SIZE = 64
# Floor division keeps the step counts integral on Python 3 as well
# (plain "/" yields floats there, which breaks range(MAX_STEP)).
one_epoch_step = train_number // BATCH_SIZE  # iterations per epoch
decay_steps = 20*one_epoch_step  # decay the learning rate every 20 epochs
MAX_STEP = 100*one_epoch_step  # train for 100 epochs
N_CLASSES = len(label_dict)
IMG_W = 224
IMG_H = 224
CAPACITY = 1000 + 3 * BATCH_SIZE
display_step = 100
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # index of the GPU to use
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow GPU memory usage on demand


def get_batch_record(record_name, IMG_W, IMG_H, batch_size, CAPACITY):
    """Build the augmented, shuffled training input pipeline from a TFRecord.

    Each example is decoded to a uint8 ``[IMG_H, IMG_W, 3]`` image, randomly
    transposed / flipped / colour-jittered, standardized, and batched.

    Args:
        record_name: path of the TFRecord file to read.
        IMG_W: stored image width in pixels.
        IMG_H: stored image height in pixels.
        batch_size: number of examples per batch.
        CAPACITY: capacity of the shuffle queue; it is used together with a
            fixed ``min_after_dequeue`` of 1000, so it must be larger than that.

    Returns:
        ``(img_batch, label_batch)``: float32 images and int32 labels reshaped
        to ``[batch_size]``.
    """
    filename_queue = tf.train.string_input_producer([record_name])  # make a queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # returns (key, record)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'image': tf.FixedLenFeature([], tf.string)})
    # get label and image data
    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(image, [IMG_H, IMG_W, 3])

    # --- data augmentation ---
    # Random transpose, decided per example inside the graph. A Python-side
    # np.random test would run only once while the graph is built, so every
    # image (or none) would be transposed for the whole training run.
    # Only square images are transposed: otherwise the two branches would
    # have different shapes and batching needs one fixed shape.
    if IMG_H == IMG_W:
        source = image
        image = tf.cond(tf.random_uniform([]) > 0.5,
                        lambda: tf.image.transpose_image(source),
                        lambda: tf.identity(source))
        image.set_shape([IMG_H, IMG_W, 3])
    # random horizontal flip
    image = tf.image.random_flip_left_right(image)
    # random vertical flip
    image = tf.image.random_flip_up_down(image)
    # random brightness
    image = tf.image.random_brightness(image, max_delta=32/255.0)
    # random contrast (disabled)
    #image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    # random hue
    image = tf.image.random_hue(image, max_delta=0.05)
    # random saturation (disabled)
    #image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    # standardize each image to zero mean and unit variance
    image = tf.image.per_image_standardization(image)

    # get batch data
    img_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                    batch_size=batch_size,
                                                    num_threads=64,
                                                    capacity=CAPACITY,
                                                    min_after_dequeue=1000)
    tf.summary.image("input_img", img_batch, max_outputs=5)
    label_batch = tf.reshape(label_batch, [batch_size])
    img_batch = tf.cast(img_batch, tf.float32)
    return img_batch, label_batch


def main():
    """Train MobileNetV2 on batches read from the TFRecord file.

    Builds the input pipeline, network, loss, Adam optimizer with an
    exponentially decayed learning rate and an accuracy metric, then runs
    ``MAX_STEP`` training steps. Every ``display_step`` steps it prints the
    metrics and writes summaries; every 2000 steps and at the last step it
    saves a checkpoint under ``logs_train_dir``.
    """
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # labels are class indices, not one-hot vectors
    batch_train, batch_labels = get_batch_record(train_record_dir, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
    # network
    logits = model.MobileNetV2(batch_train, num_classes=N_CLASSES, is_training=True).output
    # print() call form works on both Python 2 and 3
    print(logits.get_shape())
    # loss
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=batch_labels)
    loss = tf.reduce_mean(cross_entropy, name='loss')
    tf.summary.scalar('train_loss', loss)
    # optimizer
    lr = tf.train.exponential_decay(learning_rate=init_lr, global_step=global_step, decay_steps=decay_steps, decay_rate=0.1)
    tf.summary.scalar('learning_rate', lr)

    # Run the collected update ops (e.g. batch-norm statistics) with each step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss, global_step=global_step)
    # accuracy (averaged in float32 to avoid half-precision rounding)
    correct = tf.nn.in_top_k(logits, batch_labels, 1)
    correct = tf.cast(correct, tf.float32)
    accuracy = tf.reduce_mean(correct)
    tf.summary.scalar('train_acc', accuracy)

    summary_op = tf.summary.merge_all()
    sess = tf.Session(config=config)
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)

    # Save the trainable variables plus the moving mean/variance variables,
    # which are not trainable and would otherwise be left out.
    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    saver = tf.train.Saver(var_list=var_list, max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # To resume training: saver.restore(sess, logs_train_dir+'/model.ckpt-174000')

    # range() needs an int even if MAX_STEP was computed as a float
    max_step = int(MAX_STEP)
    try:
        for step in range(max_step):
            if coord.should_stop():
                break
            log_now = step % display_step == 0
            fetches = [optimizer, lr, loss, accuracy]
            if log_now:
                # Fetch the summaries in the same run as the training step so
                # they describe the batch being logged and no extra batch is
                # dequeued just for the summary.
                fetches.append(summary_op)
            results = sess.run(fetches)
            learning_rate, tra_loss, tra_acc = results[1], results[2], results[3]
            if log_now:
                print('Epoch:%3d/%d, Step:%6d/%d, lr:%f, train loss:%.4f, train acc:%.2f%%' %(step//one_epoch_step+1, max_step//one_epoch_step, step+display_step, max_step, learning_rate, tra_loss, tra_acc*100.0))
                train_writer.add_summary(results[4], step)
            if step % 2000 == 0 or (step + 1) == max_step:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Release threads, event file and session even on unexpected errors
        coord.request_stop()
        coord.join(threads)
        train_writer.close()
        sess.close()
    

# Start training only when executed as a script, not when imported
if __name__ == '__main__':
    main()

三、对图像预测方法和之前一样

下一篇:使用tensorRT加速模型 https://blog.csdn.net/u010397980/article/details/86382849

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值