tensorflow导入自己的数据集

在构建tensorflow模型过程中,可谓是曲折颇多,一些教程上教会了我们如何使用下载的现成数据集,但却没有提及如何构建自己的数据集。我自己在学习过程中也走了不少弯路,希望这一系列的博客能解决大家的一些困惑。

我们本地构建数据集主要是以下几个步骤

1.数据处理

2.数据增强 

3.数据导入

4.构建模型

5.训练模型

这篇先讲一下数据处理的一些操作,后面的步骤会慢慢发出来。

1.导入第三方库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np

这里会注意到,我在导入os库时,在后面加了

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

这句话的作用是避免报错:This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)

2.导入数据路径

data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]

我这里的./image是我本地图片集所在的文件夹,image文件夹下是两个分别保存不同种类图片的文件夹,因为我这里是做二分类,所以只有两个不同种类的文件夹,如果大家需要构建识别多种图片的模型,可以添加其他文件夹。

 3.随机打乱图片,这一步的目的是为了让图片集去特殊化,提高模型的准确率,因为如果你的图片中有比较相近的,而且数量比较多,会影响模型的学习。这一步是调用了random的shuffle,传入图片集列表,随机打乱。

random.shuffle(all_image_paths)

 4.构建标签及索引

其实是构建了一个字典

#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

5.加载和格式化图片

我们可以看到,tf.image.decode_jpeg(image,channels=3,这句话的作用是把图片变成三通道图,即RGB式图片。需要强调一下,tf.image.resize()这个小东西好用的很,可以把你的图片统一大小,这在后面我们训练模型是必须的,统一大小的图片更有利于我们的模型学习。而image/255.0是为了使图像进行归一化,得到的数值范围为[0, 1],彩色图片会变成灰图。

load_and_prepro_image()这个函数就是读取传入路径的图片集,然后返回值是经过了preprocess_image 这个函数的调用,将返回的图片处理为灰度图,比较简单暴力。

#加载和格式化图片
def preprocess_image(image):
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[192,192])
    image /= 255.0
    return image

def load_and_prepro_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
for i in range(len(all_image_paths)):
    image_path = all_image_paths[i]
    label = all_image_labels[i]
    plt.imshow(load_and_prepro_image(image_path))
    plt.grid(False)
    plt.xlabel(image_path)
    plt.title(label_names[label].title())
    #plt.show()

然后关于这个for循环,其实不是必须的,只是为了方便我们检查图片的处理效果,调用的库是matplotlib,python比较有名的绘图库。

就先到这,后会有期。

下面是全部源码,tensorflow版本是2.5,py版本3.7,cuda11.6。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np

#数据处理
#导入数据路径
data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
#随机打乱图片
random.shuffle(all_image_paths)
#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]
#加载和格式化图片
def preprocess_image(image):
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[192,192])
    image = image/255.0
    return image

def load_and_prepro_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
for i in range(len(all_image_paths)):
    image_path = all_image_paths[i]
    label = all_image_labels[i]
    plt.imshow(load_and_prepro_image(image_path))
    plt.grid(False)
    plt.xlabel(image_path)
    plt.title(label_names[label].title())
    #plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

冯简

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值