Tensorflow---使用Tensorflow进行多输出模型的构建预测种类与颜色

一、数据集可以通过以下链接下载

百度网盘提取码:lala

二、代码运行环境

Tensorflow-gpu==2.4.0

Python==3.7

三、构造种类与颜色的索引文件

import glob


# 此函数只执行一次
def write_index_text():
    image_path = glob.glob(r'dataset/*/*')
    class_names = set(path.split('\\')[1] for path in image_path)
    color_label_names = set(i.split('_')[0] for i in class_names)
    class_label_names = set(i.split('_')[1] for i in class_names)

    color_to_index = dict((name, index) for index, name in enumerate(color_label_names))
    class_to_index = dict((name, index) for index, name in enumerate(class_label_names))

    # 开始进行txt文件的写入
    with open(r'color_to_index.txt', 'w') as f:
        for item in color_to_index.items():
            file_line = ''
            for i in range(len(item)):
                file_line += str(item[i]) + ' '

            f.write(file_line)
            f.write('\n')

    with open(r'class_to_index.txt', 'w') as f:
        for item in class_to_index.items():
            file_line = ''
            for i in range(len(item)):
                file_line += str(item[i]) + ' '
            f.write(file_line)
            f.write('\n')


if __name__ == '__main__':
    write_index_text()

四、构建数据集的输入

import glob
import random
import tensorflow as tf
import os

# 环境变量的配置
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


# 获取相应的字典
def get_label_to_index():
    with open('color_to_index.txt', 'r') as f:
        color_to_index = dict((line.split(' ')[0], int(line.split(' ')[1])) for line in f.readlines())
        index_to_color = dict((v, k) for (k, v) in color_to_index.items())
    with open('class_to_index.txt', 'r') as f:
        class_to_index = dict((line.split(' ')[0], int(line.split(' ')[1])) for line in f.readlines())
        index_to_class = dict((v, k) for k, v in class_to_index.items())

    return color_to_index, index_to_color, class_to_index, index_to_class


# 进行图片的解析
def load_images(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32)
    image = image / 255.0
    image = image * 2 - 1
    return image


# 制作数据集
def make_dataset():
    co_to_index, _, cl_to_index, _ = get_label_to_index()
    all_images_path = glob.glob('dataset/*/*')
    random.shuffle(all_images_path)
    all_colors_label = [co_to_index.get(i.split('\\')[1].split('_')[0]) for i in all_images_path]
    all_classes_label = [cl_to_index.get(i.split('\\')[1].split('_')[1]) for i in all_images_path]

    image_dataset = tf.data.Dataset.from_tensor_slices(all_images_path)
    image_dataset = image_dataset.map(load_images, tf.data.experimental.AUTOTUNE)

    label_dataset = tf.data.Dataset.from_tensor_slices((all_colors_label, all_classes_label))

    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))

    count = len(all_images_path)
    test_count = int(count * 0.2)
    train_count = count - test_count

    train_dataset = dataset.skip(test_count)
    test_dataset = dataset.take(test_count)

    BATCH_SIZE = 32

    train_dataset = train_dataset.shuffle(train_count).repeat(-1)
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    test_dataset = test_dataset.batch(BATCH_SIZE)

    return train_dataset, test_dataset, train_count, test_count


if __name__ == '__main__':
    train_data, test_data, train_ct, test_ct = make_dataset()
    print(train_data)
    print(test_data)

五、进行模型的构建

import tensorflow as tf
import os

# 环境变量的配置
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


def make_model():
    mobile_net = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False)
    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = mobile_net(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    x1 = tf.keras.layers.Dense(1024, activation='relu')(x)
    output_color = tf.keras.layers.Dense(3, activation='softmax', name='output_color')(x1)

    x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
    output_class = tf.keras.layers.Dense(4, activation='softmax', name='output_class')(x2)

    model = tf.keras.Model(inputs=inputs,
                           outputs=[output_color, output_class])

    model.summary()
    return model


if __name__ == '__main__':
    make_model()

六、进行模型的训练

import tensorflow as tf
import os
from data_loader import make_dataset
from model import make_model

# 环境变量的配置
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# 加载训练数据
train_dataset, test_dataset, train_count, test_count = make_dataset()

# 加载训练的模型
my_model = make_model()

# 模型的相关配置
my_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss={
        'output_color': 'sparse_categorical_crossentropy',
        'output_class': 'sparse_categorical_crossentropy'
    },
    metrics=['acc']
)

train_steps = train_count // 32
test_steps = test_count // 32
tf_tensorboard = tf.keras.callbacks.TensorBoard('logs', histogram_freq=1)

# 模型的训练
my_model.fit(train_dataset,
             epochs=100,
             steps_per_epoch=train_steps,
             validation_data=test_dataset,
             validation_steps=test_steps,
             callbacks=[tf_tensorboard])

# 模型的保存
my_model.save(r'model_data/model.h5')

七、进行模型的预测

import numpy
import tensorflow as tf
import os
from data_loader import get_label_to_index
import matplotlib.pyplot as plt

# 环境变量的控制
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# 模型的加载
my_model = tf.keras.models.load_model(r'model_data/model.h5')

# 索引的加载
color_to_index, index_to_color, class_to_index, index_to_class = get_label_to_index()

while True:
    path = input('请输入图片路径:')
    try:
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
    except:
        print('文件路径输入错误!')
        continue
    else:
        inputs = tf.image.resize(image, [224, 224])
        inputs = tf.cast(inputs, tf.float32)
        inputs = inputs / 255.0
        inputs = inputs * 2 - 1
        inputs = tf.expand_dims(inputs, 0)
        pre = my_model.predict(inputs)
        plt.title(
            'It is a ' + index_to_color.get(numpy.argmax(pre[0])) + ' ' + index_to_class.get(numpy.argmax(pre[1])))
        plt.imshow(image)
        plt.savefig('result.jpg')
        plt.show()

八、预测结果的展示

预测结果展示

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水哥很水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值