TensorFlow自定义数据集需要继承的类:tf.keras.utils.Sequence
模板
class SiameseSequence(tf.keras.utils.Sequence):
    """Skeleton of a custom dataset derived from ``tf.keras.utils.Sequence``.

    Every method is an empty placeholder that still has to be filled in.
    """

    def __init__(self):
        """Placeholder constructor; stores nothing yet."""

    def __len__(self):
        """Placeholder for the dataset length; currently returns ``None``."""

    def __getitem__(self, idx):
        """Placeholder for fetching one batch; currently returns ``None``."""
实际应用
import numpy as np
import tensorflow as tf
import glob
# load_img 加载图像 img_to_array 数据格式转换
from tensorflow.keras.preprocessing import image
# Image preprocessing for VGG16 (the preprocess_input imported below comes from vgg16, not resnet)
from tensorflow.keras.applications.vgg16 import preprocess_input
import random
# Training set: each entry matched under dogImages/train is treated as one class.
train_classes = glob.glob('dogImages/train/*')

# Gather the .jpg paths of every class entry into one flat list.
li = []
for class_dir in train_classes:
    li.extend(glob.glob(class_dir + '/*.jpg'))

print(len(train_classes))
print(len(li))
class SiameseSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of preprocessed images.

    Item ``idx`` is one batch: the images of ``batch_size`` consecutive
    paths from ``data_list``, each run through ``preprocess_image`` and
    stacked into a single NumPy array. The last batch may be smaller.
    """

    def __init__(self, data_list, file_list, batch_size):
        """Store the dataset description.

        Args:
            data_list: Sequence of image file paths; it is sliced by
                ``__getitem__`` and measured by ``__len__``.
            file_list: Stored unchanged and not read by any method of this
                class (the caller in this file passes the number of class
                directories).
            batch_size: Number of images per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        # Paths of all image files.
        self.data_list = data_list
        # Number of classes (as passed by the caller).
        self.file_list = file_list
        # Images per batch.
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # Integer ceiling division: ``len()`` must return an int, and an
        # exact multiple of batch_size must not produce an extra empty batch.
        return (len(self.data_list) + self.batch_size - 1) // self.batch_size

    def preprocess_image(self, filename):
        """Load one image file and return it as a preprocessed array.

        The image is loaded at 224x224, converted to an array and passed
        through the module-level ``preprocess_input``.
        """
        img = image.load_img(filename, target_size=(224, 224))
        img = image.img_to_array(img)
        img = preprocess_input(img)
        return img

    def __getitem__(self, idx):
        """Return batch number ``idx`` as a stacked NumPy array.

        Args:
            idx: Batch index; negative values count from the last batch.

        Raises:
            IndexError: If ``idx`` is outside the available batches.
        """
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        if not 0 <= idx < n_batches:
            # IndexError also lets the legacy __getitem__ iteration protocol stop.
            raise IndexError("batch index out of range")
        # Take this batch's own slice of paths rather than repeating one image.
        start = idx * self.batch_size
        batch_paths = self.data_list[start:start + self.batch_size]
        return np.array([self.preprocess_image(path) for path in batch_paths])
# Build the training sequence over the collected image paths, 32 images per batch.
traingen = SiameseSequence(
    data_list=li,
    file_list=len(train_classes),
    batch_size=32,
)
# Fetch the first batch and report its array shape.
first_batch = traingen[0]
print(first_batch.shape)
import cv2
# Post-processing for display (min-max normalisation).
def changeImg(img):
    """Min-max normalise an array and return its first element.

    The minimum and maximum are taken over the whole of ``img``; all values
    are rescaled into ``[0, 1]`` and element ``0`` along the first axis of
    the result is returned.

    Args:
        img: Array-like object supporting ``.max()``, ``.min()``, arithmetic
            and indexing (the caller in this file passes one batch).

    Returns:
        The first element of the normalised array. A constant input
        yields all zeros.
    """
    max_v = img.max()
    min_v = img.min()
    # Scale by the value range, not by the maximum alone: dividing by max_v
    # leaves the result outside [0, 1] whenever min_v is non-zero.
    span = max_v - min_v
    if span == 0:
        # Constant input: the shifted values are already all zero, so any
        # non-zero divisor avoids 0/0 without changing the result.
        span = 1
    return ((img - min_v) / span)[0]
def cv_show(neme, img):
    """Show ``img`` in an OpenCV window titled ``neme``.

    Calls ``cv2.waitKey(0)`` after displaying the image and then destroys
    all OpenCV windows.
    """
    wait_delay = 0
    cv2.imshow(neme, img)
    cv2.waitKey(wait_delay)
    cv2.destroyAllWindows()
# traingen[k] is one batch; normalise batch 1 and keep the image changeImg returns.
second_batch = traingen[1]
img = changeImg(second_batch)
print(img.shape)
cv_show(neme='neme', img=img)