TensorFlow训练分类自己的图片数据

16 篇文章 5 订阅
5 篇文章 0 订阅

TensorFlow2.4

import tensorflow as tf
import pathlib
import random
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
#数据集文件夹
data_root=pathlib.Path('/Users/thrive/Library/Mobile Documents/com~apple~CloudDocs/发文章/code/dataset/classify/matlab/wavelettf+/figs')
#打印文件夹名称
for item in data_root.iterdir():
  print(item)

结果:
/Users/thrive/Library/Mobile Documents/comappleCloudDocs/发文章/code/dataset/classify/matlab/wavelettf+/figs/.DS_Store
/Users/thrive/Library/Mobile Documents/comappleCloudDocs/发文章/code/dataset/classify/matlab/wavelettf+/figs/0
/Users/thrive/Library/Mobile Documents/comappleCloudDocs/发文章/code/dataset/classify/matlab/wavelettf+/figs/1

#获取图片路径
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
print(image_count)
print(all_image_paths[0])
#获取图片大小
imagetest = tf.io.read_file(all_image_paths[0])
imagetest = tf.image.decode_jpeg(imagetest, channels=3)
print(imagetest.shape)
#获取标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print('label names: ',label_names)
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
print(label_to_index)
#创建一个列表,包含每个文件的标签索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]
print("First 10 labels indices: ", all_image_labels[:10])

结果:
100
/Users/thrive/Library/Mobile Documents/comappleCloudDocs/发文章/code/dataset/classify/matlab/wavelettf+/figs/0/0_81.jpg
(556, 556, 3)
label names: [‘0’, ‘1’]
{‘0’: 0, ‘1’: 1}
First 10 labels indices: [0, 1, 0, 1, 1, 0, 1, 0, 1, 0]

#创建数据集
batch_size = 32
img_height = imagetest.shape[0]
img_width = imagetest.shape[1]

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_root,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_root,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names
print('class names: ',class_names)

结果:
Found 100 files belonging to 2 classes.
Using 80 files for training.
Found 100 files belonging to 2 classes.
Using 20 files for validation.
class names: [‘0’, ‘1’]

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(10).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# set net
model = tf.keras.models.Sequential([
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(64, activation='relu'),
  layers.Dense(2)
])

model.summary()

在这里插入图片描述

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

epochs=10
history=model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    )

结果:
Epoch 10/10
3/3 [==============================] - 5s 2s/step - loss: 0.6931 - accuracy: 0.3958 - val_loss: 0.6931 - val_accuracy: 0.2500

准确率比随机还低一半,哭了

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

油泼西红柿

Wish U Thrive

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值