从零开始开发自己的类keras深度学习框架1 :实现数据载入模块

认真学习,佛系更博。

本章将介绍如何实现数据的载入功能,该模块和其他模块较为独立,也最容易实现;因此本章内容只做简单介绍,不做过多说明;

我们拿mnist作为示例,目标是将mnist数据转化为神经网络模型可以处理的数据。首先获取mnist原始图片数据,并保存在本地,获取mnist图片的方法见我另一篇博客:深度学习系列:从mnist数据集中提取mnist图片

拿到图片数据后我们将其放入一个文件夹中,比如我放在项目的dataset目录内,下面有10个子目录,分别对应0-9十种类型的图片,接下来将实现数据处理的功能;

先建立一个子模块我取名为enet,然后在enet下新建子模块,取名为data,然后在data下新建python文件,取名为image_controller.py,此时data下有两个文件:__init__.py和image_controller.py,不要删除__init__.py文件,后面封装库需要用到;

编辑image_controller.py文件,新建一个类:ImageHandler,其初始化需要4个参数:data_dir、gray、use_scale、flatten,其含义见代码注释:

    def __init__(self, data_dir, gray=False, use_scale=False, flatten=False):
        """
        :param data_dir: 根目录
        :param gray: 是否以灰度化格式读入图片
        :param use_scale: 是否/255.
        :param flatten: 数据结果是否需要拉伸为1维
        """
        self.data_dir = data_dir
        self.gray = gray
        self.use_scale = use_scale
        self.flatten = flatten

        self.class_dict = dict()

class_dict用于获取每个目录对应的类别,比如有两个子目录,目录名分别为“猫”和“狗”,则class_dict为{0: "猫", 1: "狗"}

然后我们需要定义一个函数获取数据,命名为get_data,获取数据的主要步骤很简单,也很流程化,首先获取子目录,然后对子目录的每一个图片,使用opencv读取图像数据,根据参数确定数据的样式,最后返回。其代码为:

        sub_dir_list = [sub_dir for sub_dir in os.listdir(self.data_dir) if
                        os.path.isdir(os.path.join(self.data_dir, sub_dir))]
        sub_dir_list.sort()

        train_data = []
        train_label = []
        # 遍历文件夹
        for dir_index, sub_dir in enumerate(sub_dir_list):

            for sub_file in os.listdir(os.path.join(self.data_dir, sub_dir)):
                if self.gray:
                    image = cv2.imread(os.path.join(self.data_dir, sub_dir, sub_file), cv2.IMREAD_GRAYSCALE)
                else:
                    image = cv2.imread(os.path.join(self.data_dir, sub_dir, sub_file))

                image = np.array(image)

                # 灰度模式读取为2维数据,需要添加1维通道信息
                if self.gray:
                    image = np.expand_dims(image, axis=-1)

                # 如果使用flatten,则应该拉成向量
                if self.flatten:
                    image = image.flatten()

                if self.use_scale:
                    image = image / 255.

                # 插入到结果集中
                train_data.append(image)
                train_label.append(dir_index)

        # train_label = np.eye(len(sub_dir_list))[train_label]
        train_label = convert_one_hot(train_label, len(sub_dir_list))

        return train_data, train_label

convert_one_hot函数将标签转化为one_hot向量,其代码为:

def convert_one_hot(input_signal, size):
    """
    将输入数据转化称one_hot数据
    :param input_signal: 输入数据
    :param size: 样本类别
    :return:
    """
    return np.eye(size)[np.asarray(input_signal)]

这里需要了解一下np.eye的用法;

另外还需要获取class_dict的功能,实现为get_class_dict,代码如下:

    def get_class_dict(self):
        """
        获取类别字典
        :return:
        """

        self.class_dict.clear()

        # 读取数据,去掉缓存文件
        sub_dir_list = [sub_dir for sub_dir in os.listdir(self.data_dir) if os.path.isdir(os.path.join(self.data_dir,
                                                                                                       sub_dir))]
        sub_dir_list.sort()

        for index, sub_dir in enumerate(sub_dir_list):
            self.class_dict[index] = sub_dir

        return self.class_dict

过程很简单,需要注意的是获取子目录的时候要排除非目录文件,因为我们添加了cache功能用于加速数据的读取,可能会生成cache文件,完整的image_handler.py代码如下:

import os
import cv2
import numpy as np
import pickle

from enet.utils.util import train_test_split, convert_one_hot


class ImageHandler(object):
    """
    数据集载入控制类,传入参数为目录,并且目录满足以下特征:
    1. 目录下所有文件都是两级结构,比如 class1/file1.png
    2. 所有文件都为图片类型,可被计算机读取
    3. 同一类图片需要放在同一个目录下

    在生成数据的同时会解析标签字典,其过程为:
    1. 将1级目录排序
    2. 按照顺序生成标签,从0开始
    3. 记录标签到目录名的字典返回class_dict
    """

    def __init__(self, data_dir, gray=False, use_scale=False, flatten=False):
        """
        :param data_dir: 根目录
        :param gray: 是否以灰度化格式读入图片
        :param use_scale: 是否/255.
        :param flatten: 数据结果是否需要拉伸为1维
        """
        self.data_dir = data_dir
        self.gray = gray
        self.use_scale = use_scale
        self.flatten = flatten

        self.class_dict = dict()

    def get_data(self, ratio=0.3, read_cache=True):
        """
        读取数据
        :param ratio: 数据拆分比例
        :param read_cache: 是否使用cache读入,若使用,则直接使用缓存数据,上述参数可能无效
        :return:
        """

        # 这里引入pickle加速读取
        if read_cache:
            if os.path.exists(os.path.join(self.data_dir, "data_cache.pkl")):
                with open(os.path.join(self.data_dir, "data_cache.pkl"), "rb") as reader:
                    return pickle.load(reader)

        train_data, train_label = self.load_data()
        result = train_test_split(train_data, train_label, ratio)

        # 保存到cache文件中
        with open(os.path.join(self.data_dir, "data_cache.pkl"), "wb") as writer:
            pickle.dump(result, writer)

        return result

    def load_data(self):
        """加载图片数据, 返回数据和标签"""
        sub_dir_list = [sub_dir for sub_dir in os.listdir(self.data_dir) if
                        os.path.isdir(os.path.join(self.data_dir, sub_dir))]
        sub_dir_list.sort()

        train_data = []
        train_label = []
        # 遍历文件夹
        for dir_index, sub_dir in enumerate(sub_dir_list):

            for sub_file in os.listdir(os.path.join(self.data_dir, sub_dir)):
                if self.gray:
                    image = cv2.imread(os.path.join(self.data_dir, sub_dir, sub_file), cv2.IMREAD_GRAYSCALE)
                else:
                    image = cv2.imread(os.path.join(self.data_dir, sub_dir, sub_file))

                image = np.array(image)

                # 灰度模式读取为2维数据,需要添加1维通道信息
                if self.gray:
                    image = np.expand_dims(image, axis=-1)

                # 如果使用flatten,则应该拉成向量
                if self.flatten:
                    image = image.flatten()

                if self.use_scale:
                    image = image / 255.

                # 插入到结果集中
                train_data.append(image)
                train_label.append(dir_index)

        # train_label = np.eye(len(sub_dir_list))[train_label]
        train_label = convert_one_hot(train_label, len(sub_dir_list))

        return train_data, train_label

    def get_class_dict(self):
        """
        获取类别字典
        :return:
        """

        self.class_dict.clear()

        # 读取数据,去掉缓存文件
        sub_dir_list = [sub_dir for sub_dir in os.listdir(self.data_dir) if os.path.isdir(os.path.join(self.data_dir,
                                                                                                       sub_dir))]
        sub_dir_list.sort()

        for index, sub_dir in enumerate(sub_dir_list):
            self.class_dict[index] = sub_dir

        return self.class_dict

如注释所写,该模块对数据目录结构有要求,必须为2级目录,每类图片放在相同的子目录内;

最后修改__init__.py内容:

from enet.data.image_controller import ImageHandler

修改后就可以在外部使用from enet.data import ImageHandler调用该类的所有功能;

到这里,该模块内容已经实现完毕,不过功能较单一,只能实现对图片类数据读取。可以尝试写其他类型数据的读取功能,也可以尝试调用该模块看看返回数据的格式、内容等;

整个项目代码见github:https://github.com/darkwhale/neural_network ;

下一章将详细介绍网络层的实现及原理;

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值