CNN分类中批量读取数据及制作标签时报错:could not broadcast input array from shape (128,128,3) into shape (128,128)

一、背景

最近在做CNN分类时,用了一个能够一次性读取文件夹下所有数据并按文件夹制作相应标签的函数。之前做13类别,总计约1000张图片时,用这个函数没有任何问题,但后后来加到约40000张图片的时候,运行就报错。

二、问题描述

用CNN做多类别分类时,往往需要大量数据,由于我自己是没有这么多数据的,所以很多数据需要从网上爬,总计约4万张图,进行CNN分类实验。

进行完数据预处理之后,关键的一个步骤就是需要将图像和标签做好传入到网络中,我使用的函数为:

import os
import glob
from skimage import io, transform
import numpy as np

width, height = 128, 128

# 定义读取图片的函数, 并将其resize成width*height尺寸大小
def read_img(image_path):
    cate = [image_path+f for f in os.listdir(image_path) if os.path.isdir(image_path+f)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder+'/*.jpg'):
            print('reading the images:%s' % im)
            img = io.imread(im)
            img = transform.resize(img, (width, height))
            imgs.append(img)
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

调用方法:

data, label = read_img('./data/')

实现的功能是读取'data/'文件夹下的所有子文件夹及其相应的数据,按文件夹制作标签。第一次图片比较少的时候,没有任何问题,一次性读取成功,然后将数据传入到了CNN网络中。

然而,第二次增加了很多图片,图片质量参差不齐,甚至有的图片是4维的(我已经删掉了),运行该程序后直接报错:

大致意思就是说,不能把(128,128,3)的图像变成(128,128)。

三、解决方法

网上遇到这个问题的人还蛮多的,看了一下,普遍认为最简单的解决办法就是把RGB变成灰度图进行处理:

来源:https://github.com/carpedm20/DCGAN-tensorflow/issues/162

当然我们需要的是处理彩色影像。那么首先需要搞清楚这个问题的原因。有人提到:

来源:https://github.com/carpedm20/DCGAN-tensorflow/issues/162

所以,这个问题的原因无非就是:大量众多的彩色图中混有个别灰度图,导致channel的数量不统一,自然无法进行broadcast。具体出现问题的细致讲解可以看这里:https://stackoverflow.com/questions/43977463/valueerror-could-not-broadcast-input-array-from-shape-224-224-3-into-shape-2

下面谈一下解决方法,因为原因是出在原始数据上的,所以我们只要对原始数据进行过滤,删选掉不符合要求的图像就可以了,根据这个思路,可以对原代码进行修改,增加一下判断步骤:

# 定义读取图片的函数, 并将其resize成width*height尺寸大小
def read_img(image_path):
    cate = [path+f for f in os.listdir(image_path) if os.path.isdir(image_path+f)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder+'/*.jpg'):
            print('reading the images:%s' % im)
            img = io.imread(im)
            try:
                if img.shape[2] == 3:
                    img = transform.resize(img, (width, height))
                    imgs.append(img)
                    labels.append(idx)
            except:
                continue
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

最终的结果:

可以看到,数据已经能顺利读取,并将数据和标签传入到了CNN中了。

 

  • 9
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

全部梭哈迟早暴富

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值