ImageDataGenerator读取的数据集转Numpy array

ImageDataGenerator读取的数据集转Numpy array


常用的数据集类型可以分为两种:
(1)一种是网上的经典数据集,比如mnist,一般会给写好的读取方法,比如mnist.load_data(),读取出来的返回值是Numpy array;
(2)一种是自己本地的数据集,路径下每个文件夹代表一类图像,目录结构类似于

data
--type1
----img1-1
----img1-2
--type2
----img2-2
----img2-2
--type3
----img3-2
----img3-2

这时就要用keras的ImageDataGenerator生成flow来读取数据。


使用shap做可解释机器学习时发现原作者给的例子里用的数据集是mnist(链接:https://github.com/slundberg/shap),计算相关的代码如下

background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]

# explain predictions of the model on four images
e = shap.DeepExplainer(model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])

# plot the feature attributions
shap.image_plot(shap_values, -x_test[1:5])

这里DeepExplainer、shap_values和image_plot的参数都要求是Numpy array,但是用ImageDataGenerator生成的数据集不是这种格式的,就需要进行转换。作者没有给相关的例子,就自己写一个,代码效率有点低,不知道有没有更好的写法,欢迎大家补充。

test_dir = './data' #数据集路径,路径下每个文件夹代表一类图像
test_pic_gen = ImageDataGenerator(rescale = 1./255) #图像预处理,和训练好的模型保持一致,通常是缩放为1/255以归一化
test_flow = test_pic_gen.flow_from_directory(test_dir, (224, 224), batch_size = 1, class_mode = 'categorical') 
	#注意这里batch_size最好设为1,方便读取;设成其他值(比如比较常用的8)需要额外进行遍历

def imgFlow2npArray(img_flow, img_sum, img_size):
    x = np.zeros(shape = (img_sum, img_size[0], img_size[1], img_size[2]))
    y = np.zeros(shape = (img_sum))
    for image in test_flow:
        img_sum = img_sum - 1
        x[img_sum] = image[0][0] #image[0]以矩阵形式保存图像数据,需要去除多余的一个维度
        y[img_sum] = image[1][0].tolist().index(1.) #image[1]保存图像对应类别,同样需要去除多余的一个维度
        if img_sum <= 0:
            break
    return x, y

x_test, y_test = imgFlow2npArray(test_flow, 5, (224, 224, 3)) #此时的x_test和y_test同(x_train, y_train), (x_test, y_test) = mnist.load_data()中的x_test和y_test

这里x_test, y_test就已经是Numpy array格式了,可以顺利跑通shap代码。


但是这里也出现了一个奇怪的问题:图表中最左边一列应该对应原始图像的位置显示为纯黑图像,暂时不知道是为啥导致的,以后找到解决方法了再补充。

跑原始shap代码的显示结果:
跑原始shap代码的显示结果
跑自己的代码的显示结果:

跑自己的代码的显示结果


2021/6/21更新:

问题解决了,自己用的数据集图像是24位RGB图,改成8位灰度图就行了,代码如下:

# shap.image_plot(shap_values, x_test[0:]) 改成下面两行代码
x_test_gray = x_test[:, :, :, :1]
shap.image_plot(shap_values, x_test_gray[0:])

结果图:
在这里插入图片描述
但是作者的教程里也有用到显示RGB图的例子,代码里没看到有什么特殊处理,不知道他是怎么做到的,以后找到办法再更新。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值