TensorFlow 训练模型流程解读(含源码)

本文详细介绍了使用TensorFlow进行花卉分类模型训练的过程,包括数据读取、模型搭建、训练及测试。在数据读取部分,通过get_files和get_batch函数处理图片和标签。模型搭建涉及变量域的使用。训练阶段创建会话并从文件队列读取数据。最后,文章讨论了如何测试训练好的模型并保存模型权重。
摘要由CSDN通过智能技术生成

使用TensorFlow 编写了一个对四种花分类的代码,其中涉及到读取数据,搭建模型,测试图片。在编写代码中有些API用的不是很熟练,因此写下此文章记录,方便日后回忆。

第一部分:读取数据和标签

一般我们训练模型,需要大量的数据,一般有数据集,但是一些特殊行业,需要自己手动收集数据。我们把数据分类放好,如图:

四种花的图片分门别类放好。我们就可以写代码读取,并标签0 1 2 3

代码如下:

train_dir = 'D:/download/flower_world-master/flower_world-master/input_data'  # 训练样本的读入路径
logs_train_dir = 'D:/download/flower_world-master/flower_world-master/log'  # logs存储路径

# train, train_label = input_data.get_files(train_dir)
train, train_label, val, val_label = input_data.get_files(train_dir, 0.3)
# 训练数据及标签
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# 测试数据及标签
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# step1:获取所有的图片路径名,存放到
# 对应的列表中,同时贴上标签,存放到label列表中。
def get_files(file_dir, ratio):
    for file in os.listdir(file_dir + '/roses'):
        roses.append(file_dir + '/roses' + '/' + file)
        label_roses.append(0)
    for file in os.listdir(file_dir + '/tulips'):
        tulips.append(file_dir + '/tulips' + '/' + file)
        label_tulips.append(1)
    for file in os.listdir(file_dir + '/dandelion'):
        dandelion.append(file_dir + '/dandelion' + '/' + file)
        label_dandelion.append(2)
    for file in os.listdir(file_dir + '/sunflowers'):
        sunflowers.append(file_dir + '/sunflowers' + '/' + file)
        label_sunflowers.append(3)

    # step2:对生成的图片路径和标签List合并成一个大数组
    image_list = np.hstack((roses, tulips, dandelion, sunflowers))
    label_list = np.hstack((label_roses, label_tulips, label_dandelion, label_sunflowers))

    # 利用shuffle打乱顺序
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    # 从打乱的temp中再取出list(img和lab)
    # image_list = list(temp[:, 0])
    # label_list = list(temp[:, 1])
    # label_list = [int(i) for i in label_list]
    # return image_list, label_list

    # 将所有的img和lab转换成list
    all_image_list = list(temp[:, 0])
    all_label_list = list(temp[:, 1])

    # 将所得List分为两部分,一部分用来训练tra,一部分用来测试val
    # ratio是测试集的比例
    n_sample = len(all_label_list)
    n_val = int(math.ceil(n_sample * ratio))  # 测试样本数
    n_train = n_sample - n_val  # 训练样本数

    tra_images = all_image_list[0:n_train]
    tra_labels = all_label_list[0:n_train]
    tra_labels = [int(float(i)) for i in tra_labels]
    val_images = all_image_list[n_train:-1]
    val_labels = all_label_list[n_train:-1]
    val_labels = [int(float(i)) for i in val_labels]

    return tra_images, tra_labels, val_images, val_labels


# ---------------------------------------------------------------------------
# --------------------生成Batch----------------------------------------------

# step1:将上面生成的List传入get_batch() ,转换类型,产生一个输入队列queue,因为img和lab
# 是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像
#   image_W, image_H, :设置好固定的图像高度和宽度
#   设置batch_size:每个batch要放多少张图片
#   capacity:一个队列最大多少
def get_batch(image, label, image_W, image_
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值