1. The dataset can be downloaded from the link below
Baidu Netdisk extraction code: lala
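The scripts below parse labels out of the folder names, so the archive is expected to unpack into a layout roughly like the sketch here (folder and file names are illustrative; each folder name has the form "<color>_<class>"). Note that the paths are split on '\\', so the code as written assumes it runs on Windows.

dataset/
    <color>_<class>/
        image001.jpg
        image002.jpg
        ...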
2. Runtime environment
tensorflow-gpu==2.4.0
Python==3.7
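An optional check, not part of the original code: confirm that this TensorFlow build can see the GPU before running the scripts below.

import tensorflow as tf

# Prints the installed version and the list of visible GPUs; an empty list means TensorFlow will fall back to the CPU.
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))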
3. Building the color and class index files
import glob


def write_index_text():
    # Collect every image path; the folder name "<color>_<class>" encodes both labels.
    image_path = glob.glob(r'dataset/*/*')
    class_names = set(path.split('\\')[1] for path in image_path)
    color_label_names = set(i.split('_')[0] for i in class_names)
    class_label_names = set(i.split('_')[1] for i in class_names)
    # Map every label name to an integer index.
    color_to_index = dict((name, index) for index, name in enumerate(color_label_names))
    class_to_index = dict((name, index) for index, name in enumerate(class_label_names))
    # Write one "name index" pair per line.
    with open(r'color_to_index.txt', 'w') as f:
        for name, index in color_to_index.items():
            f.write(f'{name} {index}\n')
    with open(r'class_to_index.txt', 'w') as f:
        for name, index in class_to_index.items():
            f.write(f'{name} {index}\n')


if __name__ == '__main__':
    write_index_text()
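After running this script, each index file contains one label per line in the form "name index", which is what the next step parses with line.split(' '). The names below are placeholders; the real ones come from your dataset folder names:

<color_name_0> 0
<color_name_1> 1
<color_name_2> 2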
4. Building the dataset input pipeline
import glob
import random
import tensorflow as tf
import os

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


def get_label_to_index():
    # Read the "name index" files written in the previous step and build the mappings in both directions.
    with open('color_to_index.txt', 'r') as f:
        color_to_index = dict((line.split(' ')[0], int(line.split(' ')[1])) for line in f.readlines())
        index_to_color = dict((v, k) for (k, v) in color_to_index.items())
    with open('class_to_index.txt', 'r') as f:
        class_to_index = dict((line.split(' ')[0], int(line.split(' ')[1])) for line in f.readlines())
        index_to_class = dict((v, k) for k, v in class_to_index.items())
    return color_to_index, index_to_color, class_to_index, index_to_class


def load_images(path):
    # Decode, resize to the MobileNetV2 input size, and scale pixels to [-1, 1].
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32)
    image = image / 255.0
    image = image * 2 - 1
    return image


def make_dataset():
    co_to_index, _, cl_to_index, _ = get_label_to_index()
    all_images_path = glob.glob('dataset/*/*')
    random.shuffle(all_images_path)
    # The folder name "<color>_<class>" provides both labels for every image.
    all_colors_label = [co_to_index.get(i.split('\\')[1].split('_')[0]) for i in all_images_path]
    all_classes_label = [cl_to_index.get(i.split('\\')[1].split('_')[1]) for i in all_images_path]
    image_dataset = tf.data.Dataset.from_tensor_slices(all_images_path)
    image_dataset = image_dataset.map(load_images, tf.data.experimental.AUTOTUNE)
    label_dataset = tf.data.Dataset.from_tensor_slices((all_colors_label, all_classes_label))
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
    # 80/20 train/test split.
    count = len(all_images_path)
    test_count = int(count * 0.2)
    train_count = count - test_count
    train_dataset = dataset.skip(test_count)
    test_dataset = dataset.take(test_count)
    BATCH_SIZE = 32
    train_dataset = train_dataset.shuffle(train_count).repeat(-1)
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
    test_dataset = test_dataset.batch(BATCH_SIZE)
    return train_dataset, test_dataset, train_count, test_count


if __name__ == '__main__':
    train_data, test_data, train_ct, test_ct = make_dataset()
    print(train_data)
    print(test_data)
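To sanity-check the pipeline, you can pull one batch and look at its shapes (a sketch that assumes the dataset folder is in place and make_dataset() above has just been run):

images, (color_labels, class_labels) = next(iter(train_data))
print(images.shape)        # expected: (32, 224, 224, 3)
print(color_labels.shape)  # expected: (32,)
print(class_labels.shape)  # expected: (32,)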
5. Building the model
import tensorflow as tf
import os

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


def make_model():
    # MobileNetV2 backbone without its classification head, shared by both outputs.
    mobile_net = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False)
    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = mobile_net(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    # Two separate heads: one for the 3 colors, one for the 4 classes.
    x1 = tf.keras.layers.Dense(1024, activation='relu')(x)
    output_color = tf.keras.layers.Dense(3, activation='softmax', name='output_color')(x1)
    x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
    output_class = tf.keras.layers.Dense(4, activation='softmax', name='output_class')(x2)
    model = tf.keras.Model(inputs=inputs,
                           outputs=[output_color, output_class])
    model.summary()
    return model


if __name__ == '__main__':
    make_model()
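The MobileNetV2 backbone is left trainable here, so the whole network is fine-tuned; setting mobile_net.trainable = False before building the heads would switch this to pure transfer learning. As an optional check, not in the original script, a random batch can be pushed through the freshly built model to confirm the two output shapes:

model = make_model()
dummy = tf.random.normal([2, 224, 224, 3])
color_pred, class_pred = model(dummy, training=False)
print(color_pred.shape, class_pred.shape)  # expected: (2, 3) and (2, 4)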
6. Training the model
import tensorflow as tf
import os
from data_loader import make_dataset
from model import make_model

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

train_dataset, test_dataset, train_count, test_count = make_dataset()
my_model = make_model()
# Each output gets its own sparse categorical cross-entropy loss, since the labels are integer indices.
my_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss={
        'output_color': 'sparse_categorical_crossentropy',
        'output_class': 'sparse_categorical_crossentropy'
    },
    metrics=['acc']
)
train_steps = train_count // 32
test_steps = test_count // 32
tf_tensorboard = tf.keras.callbacks.TensorBoard('logs', histogram_freq=1)
my_model.fit(train_dataset,
             epochs=100,
             steps_per_epoch=train_steps,
             validation_data=test_dataset,
             validation_steps=test_steps,
             callbacks=[tf_tensorboard])
my_model.save(r'model_data/model.h5')
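Make sure the model_data/ directory exists before my_model.save() runs, since the HDF5 writer does not create parent folders. An optional addition, not part of the original script: a ModelCheckpoint callback keeps the best weights seen on the validation set in case the final epoch is not the best one.

# Hypothetical extra callback; pass it to fit() as callbacks=[tf_tensorboard, checkpoint].
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'model_data/best_model.h5',  # illustrative path next to model.h5
    monitor='val_loss',
    save_best_only=True)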
7. Running predictions with the model
import numpy
import tensorflow as tf
import os
from data_loader import get_label_to_index
import matplotlib.pyplot as plt

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

my_model = tf.keras.models.load_model(r'model_data/model.h5')
color_to_index, index_to_color, class_to_index, index_to_class = get_label_to_index()

while True:
    path = input('Please enter an image path: ')
    try:
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
    except Exception:
        print('Invalid file path!')
        continue
    else:
        # Apply the same preprocessing as during training, then add a batch dimension.
        inputs = tf.image.resize(image, [224, 224])
        inputs = tf.cast(inputs, tf.float32)
        inputs = inputs / 255.0
        inputs = inputs * 2 - 1
        inputs = tf.expand_dims(inputs, 0)
        pre = my_model.predict(inputs)
        plt.title(
            'It is a ' + index_to_color.get(numpy.argmax(pre[0])) + ' ' + index_to_class.get(numpy.argmax(pre[1])))
        plt.imshow(image)
        plt.savefig('result.jpg')
        plt.show()
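If the prediction steps need to be reused from other code, they can be wrapped into a small helper, a sketch that assumes my_model and the index dictionaries loaded above are available:

def predict_image(path):
    # Same preprocessing as in training: decode, resize, scale to [-1, 1], add a batch dimension.
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    inputs = tf.image.resize(image, [224, 224])
    inputs = tf.cast(inputs, tf.float32) / 255.0 * 2 - 1
    inputs = tf.expand_dims(inputs, 0)
    color_pred, class_pred = my_model.predict(inputs)
    return (index_to_color.get(int(numpy.argmax(color_pred))),
            index_to_class.get(int(numpy.argmax(class_pred))))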
8. Displaying the prediction results