MindRecord之FileWriter类

在开始之前,首先声明本篇文章参考官方文档编程指南,我基于官网的文章加以自己的理解发表了这篇博客,希望大家能够更快更简单直观的体验MindSpore,如有不妥的地方欢迎大家指正。

希望大家为我点个赞,码字不易啊。

【本文代码编译环境为MindSpore1.3.0 CPU版本】

经过手写数字识别初体验的介绍,我想大家对于mindspore文件夹的各个模块的功能已经有了大概的了解。在后续的文章中,我将按照训练一个神经网络的步骤,从数据集加载开始直到模型的成功验证推理,逐个地摸索每个模块的功能。在本篇文章中,我们来看一下MindSpore自定义数据格式MindRecord,它的一些主要的方法在mindspore.mindrecord模块。使用MindRecord格式的数据去训练网络,可以获得更好的性能提升。

MindRecord具备的特征如下:

  1. 实现多变的用户数据统一存储、访问,训练数据读取更加简便。
  2. 数据聚合存储,高效读取,且方便管理、移动。
  3. 高效的数据编解码操作,对用户透明、无感知。
  4. 可以灵活控制分区的大小,实现分布式训练。

我们使用MindRecord的目标是归一化提供训练测试所用的数据集,并通过dataset模块的相关方法进行数据的读取,将这些高效的数据投入训练。

data-conversion-concept

在mindspore.mindrecord模块中定义了一个专门的类FileWriter将用户定义的原始数据写入MindRecord文件。这个类的源码和相关方法的功能将是本篇文章的重要内容。

个人理解

首先,我认为FileWriter类是一个统筹者,它的功能是将原始数据文件写入MindRecord文件,但有一些基本的功能,如打开一个MindRecord文件,将数据写入MindRecord文件,关闭MindRecord文件。它们很重要,但他们并不是FileWriter类实现的,有三个专门的类来实现这些功能。一个是ShardHeader类:在C++模块中代表SHDDHead类的包装类。该类将存储MindRecord文件的元数据。另一个是SharedWriter类:在C++模块中代表SARDWRITE类的包装类。该类将将验证处理好的数据写入MindRecord文件。还有一个是ShardReader类:在C++模块中代表SARDRADER类的包装类。该类可以打开或关闭MindRecord文件,并且可以从MindRecord文件系列中读取一批数据。这三个类相当于是三个具有特定功能的小兵,而FileWriter类则可以看成一个厉害的将军,可以合理的调配三个拥有特定能力的小兵,以实现将原始数据写入MindRecord文件的任务。

其次,我们来看一下如何通过FileWriter类的调度来实现数据写入MindRecord文件的操作。首先,我们的实例化创建一个FileWriter对象,在对象的初始化过程中,我们会声明好一个SharedHeader对象和一个ShareWriter对象。然后我们调用FileWriter对象的add_schema方法,这里我告诉你这个add_schema方法是用来声明数据格式模板的,但实际上FileWriter类中并没有实现这个方法的具体功能,只是提供或处理了一些参数,这个时候它就调用了初始化过程中声明的SharedHeader对象的add_schema方法,同理,我们将数据写入MindRecord文件要调用FileWriter对象的write_raw_data方法,FileWriter类中这个方法,实际上也只提供了参数而没有实现具体功能,我们需要调用SharedWriter类的write_raw_data方法。ShareReader类中的方法同样会在FileWriter类中的某些方法中被调用。

这样做的的话,很多的体现了代码的封装性。它既减少了FileWriter类的代码量,也更好的文件读取,写入等功能封装在各个类中,对于项目具有很强的友好性。

FileWriter主要方法介绍

class FileWriter:
  • 类的初始化方法def init(self, file_name, shard_num=1):

参数:file_name:数据格式:str(字符串) 含义:原始数据要写入MindRecord文件夹的名字

​ shard_num:数据格式:int(整型) 含义:生成MindRecord文件的个数,取值范围为[1, 1000],默认值为1

源码

    def __init__(self, file_name, shard_num=1):
        #检查路径中的文件名(file_name)是否可用
        check_filename(file_name)
        self._file_name = file_name
		
        #检查shard_num是否合法,包括是否为空值,是否为int类型,是否处在取值范围内
        if shard_num is not None:
            if isinstance(shard_num, int):
                if shard_num < MIN_SHARD_COUNT or shard_num > MAX_SHARD_COUNT:
                    raise ParamValueError("Shard number should between {} and {}."
                                          .format(MIN_SHARD_COUNT, MAX_SHARD_COUNT))
            else:
                raise ParamValueError("Shard num is illegal.")
        else:
            raise ParamValueError("Shard num is illegal.")

        self._shard_num = shard_num
        self._index_generator = True
        suffix_shard_size = len(str(self._shard_num - 1))

        if self._shard_num == 1:
            self._paths = [self._file_name]
        else:
            self._paths = ["{}{}".format(self._file_name,
                                         str(x).rjust(suffix_shard_size, '0'))
                           for x in range(self._shard_num)]

        self._append = False
        self._header = ShardHeader()
        self._writer = ShardWriter()
        self._generator = None

注意:一旦初始化方法生效之后,尽量不要更改生成的MindRecord文件的文件名,因为文件名更改可能造成文件无法读取。

  • 类方法:def open_for_append(cls, file_name):

参数:file_name:数据格式:str(字符串) 含义:MindRecord文件夹名称

功能:打开MindRecord文件,准备添加数据

返回值:打开的MindRecord文件的file writer对象

异常处理(raise):

  1. ParamValueError(传入参数错误):如果文件名无效
  2. FileNameError:如果路径包含无效字符
  3. MRMOpenError:如果无法打开MindRecord文件。
  4. mrmopenfrappenderror:如果无法打开附加数据的文件。

源码

   @classmethod
    def open_for_append(cls, file_name):
        check_filename(file_name)
        # construct ShardHeader
        # ShardReader类:在C++模块中代表SARDRADER类的包装类。该类将
        # 从MindRecord文件系列中读取一批数据。
        # open方法:打开文件并准备读取MindRecord文件。
        # close方法:关闭MindRecord文件
        
        # ShardHeader类:在C++模块中代表SHDDHead类的包装类。该类将
        # 存储MindRecord文件的元数据。初始化时需传入参数header,默认为None
        # get_header方法:返回MindRecord文件的head
        
        reader = ShardReader()
        reader.open(file_name)
        header = ShardHeader(reader.get_header())
        reader.close()

        instance = cls("append")
        instance.init_append(file_name, header)
        return instance
	
     def init_append(self, file_name, header):
     #将默认参数_append(False)改为True,表示MindRecord文件打开成功
     #将self._header实例化为已打开MindRecord文件的对象
     self._append = True
     self._file_name = file_name
     self._header = header
     self._writer.open_for_append(file_name)

  • 类方法:def add_schema(self, content, desc=None):

**参数:**content:数据格式:dict(字典对象) 含义:模式内容的字典

​ desc: 数据格式:str(字符串) 含义:架构描述,默认值:None

**功能:**添加模式(字典的形式)来描述需要写入的原始数据的形式

**返回值:**一个int整型数,表示该schema对于的id

异常处理(raise):

  1. MRMInvalidSchemaError:如果所写的schema无效,可能时格式错误
  2. MRMBuildSchemaError:如果未能生成schema
  3. MRMAddSchemaError:如果未能田间schema

源码:

def add_schema(self, content, desc=None):
    	# 调用self._validate_schema判断schema是否合理,ret为Ture or False,error_msg
        # 为错误信息
        ret, error_msg = self._validate_schema(content)
        
        # 若ret为False,即schema不合理,raise异常处理——1
        if ret is False:
            raise MRMInvalidSchemaError(error_msg)
            
        # 调用self._header.build_schema以一个合理的原始的schema生成具体的schema对象,
        # 此处的schema变量是ShardSchema类的对象
        schema = self._header.build_schema(content, desc)
        
        # 调用ShardHeader类的add_schema方法,为ShareHeader添加一个schema对象,并返回它的id
        return self._header.add_schema(schema)
    
# 参数:content:数据格式:dic(字典对象)    含义:schema的原始格式
# 功能:判断schema是否合理,若不合理的话,会收集错误信息
# 返回值:1. bool类型,schema是否合理    2.str(字符串类型),若schema不合理,则它会携带错误信息
def _validate_schema(self, content):
        """
        Validate schema and return validation result and error message.

        Args:
           content (dict): Dict of raw schema.

        Returns:
            bool, whether the schema is valid.
            str, error message.
        """
        error = ''
        # 判断content是否为空
        if not content:
            error = 'Schema content is empty.'
            return False, error
        # 判断content是否为字典对象
        if isinstance(content, dict) is False:
            error = 'Schema content should be dict.'
            return False, error
        # 判断content是否按照schema的格式规则,以及是否符合字典的构造规则
        for k, v in content.items():
            if not re.match(r'^[0-9a-zA-Z\_]+$', k):
                error = "Field '{}' should be composed of " \
                        "'0-9' or 'a-z' or 'A-Z' or '_'.".format(k)
                return False, error
            if v and isinstance(v, dict):
                if len(v) == 1 and 'type' in v:
                    if v['type'] not in VALID_ATTRIBUTES:
                        error = "Field '{}' contain illegal " \
                                "attribute '{}'.".format(k, v['type'])
                        return False, error
                elif len(v) == 2 and 'type' in v:
                    res_1, res_2 = self._validate_array(k, v)
                    if res_1 is not True:
                        return res_1, res_2
                else:
                    error = "Field '{}' contains illegal attributes.".format(v)
                    return False, error
            else:
                error = "Field '{}' should be dict.".format(k)
                return False, error
        return True, error

# 参数:content:数据格式:dic(字典对象)    含义:schema的原始格式,用户所定义的格式
#      desc:数据格式:str(字符串)    含义:描述schema的大概所用
# 功能:建立一个原始的schema(一个dict对象),用户按照schema生成具体的schema对象
# 返回值:ShardSchema类

def build_schema(self, content, desc=None):
        """
        Build raw schema to generate schema object.

        Args:
            content (dict): Dict of user defined schema.
            desc (str,optional): String of schema description.

        Returns:
           Class ShardSchema.

        Raises:
            MRMBuildSchemaError: If failed to build schema.
        """
        desc = desc if desc else ""
        schema = ms.Schema.build(desc, content)
        if not schema:
            logger.error("Failed to add build schema.")
            raise MRMBuildSchemaError
        return schema
  • **类方法:**def add_index_fields(self, index_fields):

参数: index_fields:数据格式: list[str],字符串类型的列表 含义:schema的索引字段

**功能:**可以从schema中选择索引字段以加速读取

**返回值:**MSRStatus, SUCCESS or FAILED.

异常处理(raise):

  1. ParamTypeError:如果索引字段无效
  2. MRMDefineIndexError:如果索引字段不是基元类型
  3. MRMAddIndexError–如果添加索引字段失败
  4. MRMGetMetaError–如果未设置架构或未能获取元数据

源码:

    def add_index(self, index_fields):
        
        # 判断是否符合变量名规范,判断是否为列表
        if not index_fields or not isinstance(index_fields, list):
            raise ParamTypeError('index_fields', 'list')
		# 判断是否为基元类型
        for field in index_fields:
            if field in self._header.blob_fields:
                raise MRMDefineIndexError("Failed to set field {} since it's not primitive type.".format(field))
            if not isinstance(field, str):
                raise ParamTypeError('index field', 'str')
        # 调用调用ShardHeader类的add_index_fields方法,返回MSRStatus, SUCCESS or FAILED
        # 若为FAILED,会打印错误信息:“Failed to add index field.”
        return self._header.add_index_fields(index_fields)

  • **类方法:**def write_raw_data(self, raw_data, parallel_writer=False):

**参数:**raw_data:数据格式:list[dict] 含义:原始数据列表

​ parallel_writer:数据格式:bool(默认值为False) 含义:是否并行写入数据

**功能:**根据原始数据raw_data与schema的相应格式验证后,将原始数据转化为一系列连续的MindRecord文件

返回值: SUCCESS or FAILED

源码:

 def write_raw_data(self, raw_data, parallel_writer=False):
        if not self._writer.is_open:
            self._writer.open(self._paths)
        if not self._writer.get_shard_header():
            self._writer.set_shard_header(self._header)
        if not isinstance(raw_data, list):
            raise ParamTypeError('raw_data', 'list')
        for each_raw in raw_data:
            if not isinstance(each_raw, dict):
                raise ParamTypeError('raw_data item', 'dict')
        self._verify_based_on_schema(raw_data)
        return self._writer.write_raw_data(raw_data, True, parallel_writer)

   # 参数:raw_data:数据格式:list[dict]    含义:原始数据列表
   # 功能:根据schema验证每一行数据,若验证失败,则删除无效数据
   # 允许的数据类型包括:“int32”、“int64”、“float32”、“float64”、“string”、“bytes”
   def _verify_based_on_schema(self, raw_data):
        error_data_dic = {}
        schema_content = self._header.schema
        for field in schema_content:
            for i, v in enumerate(raw_data):
                if i in error_data_dic:
                    continue

                if field not in v:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "there is not '{}' object in the raw data.".format(i, field)
                    continue
                field_type = type(v[field]).__name__
                if field_type not in VALUE_TYPE_MAP:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "data type for '{}' is not matched.".format(i, field)
                    continue

                if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "data type for '{}' is not matched.".format(i, field)
                    continue

                if field_type == 'ndarray':
                    if 'shape' not in schema_content[field]:
                        error_data_dic[i] = "for schema, {} th data is wrong, " \
                                            "data type for '{}' is not matched.".format(i, field)
                    else:
                        try:
                            np.reshape(v[field], schema_content[field]['shape'])
                        except ValueError:
                            error_data_dic[i] = "for schema, {} th data is wrong, " \
                                                "data type for '{}' is not matched.".format(i, field)
        error_data_dic = sorted(error_data_dic.items(), reverse=True)
        for i, v in error_data_dic:
            raw_data.pop(i)
            logger.warning(v)

实例

代码:

from mindspore.mindrecord import FileWriter

# 定义一个schema,为双重字典的嵌套,file_name对{"type": "string"},label对{"type": "int32"}
# data对{"type": "bytes"},我们需要将写入的数据data按照schema的格式给出
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}

# data为要写入MindRecord的数据,可以将其与cv_schema_json比较验证它的格式,图片以二进制流给出
data = [{"file_name": "1.jpg", "label": 0, "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff\xd9"},
        {"file_name": "2.jpg", "label": 56, "data": b"\xe6\xda\xd1\xae\x07\xb8>\xd4\x00\xf8\x129\x15\xd9\xf2q\xc0\xa2\x91YFUO\x1dsE1\x1ep"},
        {"file_name": "3.jpg", "label": 99, "data": b"\xaf\xafU<\xb8|6\xbd}\xc1\x99[\xeaj+\x8f\x84\xd3\xcc\xa0,i\xbb\xb9-\xcdz\xecp{T\xb1\xdb"}]

# 添加索引
indexes = ["file_name", "label"]

# 创建4个MindRecod文件,又含有索引,所以最后又8个MindRecord文件,分别为test.mindrecord0、
# test.mindrecord0.db、test.mindrecord1、test.mindrecord1.db、test.mindrecord2、
# test.mindrecord2.db、test.mindrecord3、test.mindrecord3.db
# 它被称为MindSpore数据集,test.mindrecord0为数据文件,test.mindrecord0.db为索引文件
writer = FileWriter(file_name="test.mindrecord", shard_num=4)
writer.add_schema(cv_schema_json, "test_schema")
writer.add_index(indexes)
writer.write_raw_data(data)
# 将最终内存写入磁盘
writer.commit()

运行截图:

在这里插入图片描述
可以看到运行结果,为我们在FileWriter类方法介绍中提到的MSRStatus SUCCESS,表示写入数据写入磁盘成功,那我们来看一下它在磁盘中的显示。
在这里插入图片描述
源码地址:https://gitee.com/mindspore/mindspore/tree/master/mindspore/mindrecord

好了到这里本篇文章就结束了,感谢大家的阅读。下篇文章我们再见。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZW钟文

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值