ImageDataGenerator读取的数据集转Numpy array
常用的数据集类型可以分为两种:
(1)一种是网上的经典数据集,比如mnist,一般会给写好的读取方法,比如mnist.load_data(),读取出来的返回值是Numpy array;
(2)一种是自己本地的数据集,路径下每个文件夹代表一类图像,目录结构类似于
data
--type1
----img1-1
----img1-2
--type2
----img2-2
----img2-2
--type3
----img3-2
----img3-2
这时就要用keras的ImageDataGenerator生成flow来读取数据。
使用shap做可解释机器学习时发现原作者给的例子里用的数据集是mnist(链接:https://github.com/slundberg/shap),计算相关的代码如下
background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]
# explain predictions of the model on four images
e = shap.DeepExplainer(model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])
# plot the feature attributions
shap.image_plot(shap_values, -x_test[1:5])
这里DeepExplainer、shap_values和image_plot的参数都要求是Numpy array,但是用ImageDataGenerator生成的数据集不是这种格式的,就需要进行转换。作者没有给相关的例子,就自己写一个,代码效率有点低,不知道有没有更好的写法,欢迎大家补充。
test_dir = './data' #数据集路径,路径下每个文件夹代表一类图像
test_pic_gen = ImageDataGenerator(rescale = 1./255) #图像预处理,和训练好的模型保持一致,通常是缩放为1/255以归一化
test_flow = test_pic_gen.flow_from_directory(test_dir, (224, 224), batch_size = 1, class_mode = 'categorical')
#注意这里batch_size最好设为1,方便读取;设成其他值(比如比较常用的8)需要额外进行遍历
def imgFlow2npArray(img_flow, img_sum, img_size):
x = np.zeros(shape = (img_sum, img_size[0], img_size[1], img_size[2]))
y = np.zeros(shape = (img_sum))
for image in test_flow:
img_sum = img_sum - 1
x[img_sum] = image[0][0] #image[0]以矩阵形式保存图像数据,需要去除多余的一个维度
y[img_sum] = image[1][0].tolist().index(1.) #image[1]保存图像对应类别,同样需要去除多余的一个维度
if img_sum <= 0:
break
return x, y
x_test, y_test = imgFlow2npArray(test_flow, 5, (224, 224, 3)) #此时的x_test和y_test同(x_train, y_train), (x_test, y_test) = mnist.load_data()中的x_test和y_test
这里x_test, y_test就已经是Numpy array格式了,可以顺利跑通shap代码。
但是这里也出现了一个奇怪的问题:图表中最左边一列应该对应原始图像的位置显示为纯黑图像,暂时不知道是为啥导致的,以后找到解决方法了再补充。
跑原始shap代码的显示结果:
跑自己的代码的显示结果:
2021/6/21更新:
问题解决了,自己用的数据集图像是24位RGB图,改成8位灰度图就行了,代码如下:
# shap.image_plot(shap_values, x_test[0:]) 改成下面两行代码
x_test_gray = x_test[:, :, :, :1]
shap.image_plot(shap_values, x_test_gray[0:])
结果图:
但是作者的教程里也有用到显示RGB图的例子,代码里没看到有什么特殊处理,不知道他是怎么做到的,以后找到办法再更新。