ImageDataGenerator读取的数据集转Numpy array

ImageDataGenerator读取的数据集转Numpy array


常用的数据集类型可以分为两种:
(1)一种是网上的经典数据集,比如mnist,一般会给写好的读取方法,比如mnist.load_data(),读取出来的返回值是Numpy array;
(2)一种是自己本地的数据集,路径下每个文件夹代表一类图像,目录结构类似于

data
--type1
----img1-1
----img1-2
--type2
----img2-2
----img2-2
--type3
----img3-2
----img3-2

这时就要用keras的ImageDataGenerator生成flow来读取数据。


使用shap做可解释机器学习时发现原作者给的例子里用的数据集是mnist(链接:https://github.com/slundberg/shap),计算相关的代码如下

background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]

# explain predictions of the model on four images
e = shap.DeepExplainer(model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])

# plot the feature attributions
shap.image_plot(shap_values, -x_test[1:5])

这里DeepExplainer、shap_values和image_plot的参数都要求是Numpy array,但是用ImageDataGenerator生成的数据集不是这种格式的,就需要进行转换。作者没有给相关的例子,就自己写一个,代码效率有点低,不知道有没有更好的写法,欢迎大家补充。

test_dir = './data' #数据集路径,路径下每个文件夹代表一类图像
test_pic_gen = ImageDataGenerator(rescale = 1./255) #图像预处理,和训练好的模型保持一致,通常是缩放为1/255以归一化
test_flow = test_pic_gen.flow_from_directory(test_dir, (224, 224), batch_size = 1, class_mode = 'categorical') 
	#注意这里batch_size最好设为1,方便读取;设成其他值(比如比较常用的8)需要额外进行遍历

def imgFlow2npArray(img_flow, img_sum, img_size):
    x = np.zeros(shape = (img_sum, img_size[0], img_size[1], img_size[2]))
    y = np.zeros(shape = (img_sum))
    for image in test_flow:
        img_sum = img_sum - 1
        x[img_sum] = image[0][0] #image[0]以矩阵形式保存图像数据,需要去除多余的一个维度
        y[img_sum] = image[1][0].tolist().index(1.) #image[1]保存图像对应类别,同样需要去除多余的一个维度
        if img_sum <= 0:
            break
    return x, y

x_test, y_test = imgFlow2npArray(test_flow, 5, (224, 224, 3)) #此时的x_test和y_test同(x_train, y_train), (x_test, y_test) = mnist.load_data()中的x_test和y_test

这里x_test, y_test就已经是Numpy array格式了,可以顺利跑通shap代码。


但是这里也出现了一个奇怪的问题:图表中最左边一列应该对应原始图像的位置显示为纯黑图像,暂时不知道是为啥导致的,以后找到解决方法了再补充。

跑原始shap代码的显示结果:
跑原始shap代码的显示结果
跑自己的代码的显示结果:

跑自己的代码的显示结果


2021/6/21更新:

问题解决了,自己用的数据集图像是24位RGB图,改成8位灰度图就行了,代码如下:

# shap.image_plot(shap_values, x_test[0:]) 改成下面两行代码
x_test_gray = x_test[:, :, :, :1]
shap.image_plot(shap_values, x_test_gray[0:])

结果图:
在这里插入图片描述
但是作者的教程里也有用到显示RGB图的例子,代码里没看到有什么特殊处理,不知道他是怎么做到的,以后找到办法再更新。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
可以使用 `ImageDataGenerator` 类从 COCO 数据集读取图像并进行数据增强。对于 COCO 数据集中的人体关键点,我们需要使用 COCO API 获取关键点的位置信息,然后将其与图像一起传递给 `ImageDataGenerator` 进行增强。 以下是一个示例代码,展示了如何使用 `ImageDataGenerator` 对 COCO 数据集中的 person keypoints 进行数据增强: ```python import cv2 import numpy as np from pycocotools.coco import COCO from keras.preprocessing.image import ImageDataGenerator # 数据集路径和注释文件路径 dataDir = 'path to dataset' dataType = 'train2017' annFile = '{}/annotations/person_keypoints_{}.json'.format(dataDir, dataType) # 加载 COCO 数据集 coco = COCO(annFile) # 定义 ImageDataGenerator data_gen = ImageDataGenerator( rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, zoom_range=0.2, horizontal_flip=True, vertical_flip=True, fill_mode='nearest') # 获取 COCO 数据集中的所有 person keypoints 的 ID catIds = coco.getCatIds(catNms=['person']) imgIds = coco.getImgIds(catIds=catIds) # 遍历所有图像 for imgId in imgIds: # 从 COCO API 中获取图像的注释 annIds = coco.getAnnIds(imgIds=imgId, catIds=catIds, iscrowd=None) anns = coco.loadAnns(annIds) # 从 COCO API 中获取图像的信息 img = coco.loadImgs(imgId)[0] img_path = '{}/images/{}/{}'.format(dataDir, dataType, img['file_name']) # 读取图像 img_data = cv2.imread(img_path) # 将 person keypoints 的位置信息提取出来 keypoint_coords = [] for ann in anns: keypoints = ann['keypoints'] for i in range(0, len(keypoints), 3): x, y, v = keypoints[i:i+3] if v > 0: keypoint_coords.append((x, y)) # 将图像和关键点坐标传递给 ImageDataGenerator 进行增强 keypoint_coords = np.array(keypoint_coords).reshape((-1, 2)) gen_data = data_gen.flow( np.expand_dims(img_data, axis=0), np.expand_dims(keypoint_coords, axis=0)) # 保存增强后的图像和关键点坐标 for i in range(10): # 生成 10 个样本 gen_img_data, gen_keypoint_coords = next(gen_data) gen_img = gen_img_data[0].astype(np.uint8) gen_keypoint_coords = gen_keypoint_coords[0].reshape((-1,)) for j in range(0, len(gen_keypoint_coords), 2): x, y = gen_keypoint_coords[j:j+2].astype(np.int32) cv2.circle(gen_img, (x, y), 3, (0, 255, 0), -1) cv2.imwrite('gen_{}_{}.jpg'.format(imgId, i), gen_img) ``` 在上述代码中,我们首先加载了 COCO 数据集,并使用 `getCatsIds` 方法获取了包含 person keypoints 的所有注释的 ID。然后,我们遍历了所有图像,从 COCO API 中获取了每个图像的注释和信息,并将其传递给 `ImageDataGenerator` 进行数据增强。最后,我们将增强后的图像和关键点坐标保存到磁盘上。 注意,在这个示例中,我们只进行了简单的数据增强,你可以根据具体情况调整 `ImageDataGenerator` 的参数。此外,由于人体关键点的数量固定为 17 个,因此我们使用了一个固定大小的数组来存储关键点的位置信息。如果你想使用 COCO 数据集中的其他关键点,需要相应地修改代码。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值