在开始之前,首先声明本篇文章参考官方文档和编程指南,我基于官网的文章加以自己的理解发表了这篇博客,希望大家能够更快更简单直观的体验MindSpore,如有不妥的地方欢迎大家指正。
希望大家为我点个赞,码字不易啊。
【本文代码编译环境为MindSpore1.3.0 CPU版本】
经过手写数字识别初体验的介绍,我想大家对于mindspore文件夹的各个模块的功能已经有了大概的了解。在后续的文章中,我将按照训练一个神经网络的步骤,从数据集加载开始直到模型的成功验证推理,逐个地摸索每个模块的功能。在本篇文章中,我们来看一下MindSpore自定义数据格式MindRecord,它的一些主要的方法在mindspore.mindrecord模块。使用MindRecord格式的数据去训练网络,可以获得更好的性能提升。
MindRecord具备的特征如下:
- 实现多变的用户数据统一存储、访问,训练数据读取更加简便。
- 数据聚合存储,高效读取,且方便管理、移动。
- 高效的数据编解码操作,对用户透明、无感知。
- 可以灵活控制分区的大小,实现分布式训练。
我们使用MindRecord的目标是归一化提供训练测试所用的数据集,并通过dataset模块的相关方法进行数据的读取,将这些高效的数据投入训练。
在mindspore.mindrecord模块中定义了一个专门的类FileWriter将用户定义的原始数据写入MindRecord文件。这个类的源码和相关方法的功能将是本篇文章的重要内容。
个人理解
首先,我认为FileWriter类是一个统筹者,它的功能是将原始数据文件写入MindRecord文件,但有一些基本的功能,如打开一个MindRecord文件,将数据写入MindRecord文件,关闭MindRecord文件。它们很重要,但他们并不是FileWriter类实现的,有三个专门的类来实现这些功能。一个是ShardHeader类:在C++模块中代表SHDDHead类的包装类。该类将存储MindRecord文件的元数据。另一个是SharedWriter类:在C++模块中代表SARDWRITE类的包装类。该类将将验证处理好的数据写入MindRecord文件。还有一个是ShardReader类:在C++模块中代表SARDRADER类的包装类。该类可以打开或关闭MindRecord文件,并且可以从MindRecord文件系列中读取一批数据。这三个类相当于是三个具有特定功能的小兵,而FileWriter类则可以看成一个厉害的将军,可以合理的调配三个拥有特定能力的小兵,以实现将原始数据写入MindRecord文件的任务。
其次,我们来看一下如何通过FileWriter类的调度来实现数据写入MindRecord文件的操作。首先,我们的实例化创建一个FileWriter对象,在对象的初始化过程中,我们会声明好一个SharedHeader对象和一个ShareWriter对象。然后我们调用FileWriter对象的add_schema方法,这里我告诉你这个add_schema方法是用来声明数据格式模板的,但实际上FileWriter类中并没有实现这个方法的具体功能,只是提供或处理了一些参数,这个时候它就调用了初始化过程中声明的SharedHeader对象的add_schema方法,同理,我们将数据写入MindRecord文件要调用FileWriter对象的write_raw_data方法,FileWriter类中这个方法,实际上也只提供了参数而没有实现具体功能,我们需要调用SharedWriter类的write_raw_data方法。ShareReader类中的方法同样会在FileWriter类中的某些方法中被调用。
这样做的的话,很多的体现了代码的封装性。它既减少了FileWriter类的代码量,也更好的文件读取,写入等功能封装在各个类中,对于项目具有很强的友好性。
FileWriter主要方法介绍
class FileWriter:
- 类的初始化方法def init(self, file_name, shard_num=1):
参数:file_name:数据格式:str(字符串) 含义:原始数据要写入MindRecord文件夹的名字
shard_num:数据格式:int(整型) 含义:生成MindRecord文件的个数,取值范围为[1, 1000],默认值为1
源码:
def __init__(self, file_name, shard_num=1):
#检查路径中的文件名(file_name)是否可用
check_filename(file_name)
self._file_name = file_name
#检查shard_num是否合法,包括是否为空值,是否为int类型,是否处在取值范围内
if shard_num is not None:
if isinstance(shard_num, int):
if shard_num < MIN_SHARD_COUNT or shard_num > MAX_SHARD_COUNT:
raise ParamValueError("Shard number should between {} and {}."
.format(MIN_SHARD_COUNT, MAX_SHARD_COUNT))
else:
raise ParamValueError("Shard num is illegal.")
else:
raise ParamValueError("Shard num is illegal.")
self._shard_num = shard_num
self._index_generator = True
suffix_shard_size = len(str(self._shard_num - 1))
if self._shard_num == 1:
self._paths = [self._file_name]
else:
self._paths = ["{}{}".format(self._file_name,
str(x).rjust(suffix_shard_size, '0'))
for x in range(self._shard_num)]
self._append = False
self._header = ShardHeader()
self._writer = ShardWriter()
self._generator = None
注意:一旦初始化方法生效之后,尽量不要更改生成的MindRecord文件的文件名,因为文件名更改可能造成文件无法读取。
- 类方法:def open_for_append(cls, file_name):
参数:file_name:数据格式:str(字符串) 含义:MindRecord文件夹名称
功能:打开MindRecord文件,准备添加数据
返回值:打开的MindRecord文件的file writer对象
异常处理(raise):
- ParamValueError(传入参数错误):如果文件名无效
- FileNameError:如果路径包含无效字符
- MRMOpenError:如果无法打开MindRecord文件。
- mrmopenfrappenderror:如果无法打开附加数据的文件。
源码:
@classmethod
def open_for_append(cls, file_name):
check_filename(file_name)
# construct ShardHeader
# ShardReader类:在C++模块中代表SARDRADER类的包装类。该类将
# 从MindRecord文件系列中读取一批数据。
# open方法:打开文件并准备读取MindRecord文件。
# close方法:关闭MindRecord文件
# ShardHeader类:在C++模块中代表SHDDHead类的包装类。该类将
# 存储MindRecord文件的元数据。初始化时需传入参数header,默认为None
# get_header方法:返回MindRecord文件的head
reader = ShardReader()
reader.open(file_name)
header = ShardHeader(reader.get_header())
reader.close()
instance = cls("append")
instance.init_append(file_name, header)
return instance
def init_append(self, file_name, header):
#将默认参数_append(False)改为True,表示MindRecord文件打开成功
#将self._header实例化为已打开MindRecord文件的对象
self._append = True
self._file_name = file_name
self._header = header
self._writer.open_for_append(file_name)
- 类方法:def add_schema(self, content, desc=None):
**参数:**content:数据格式:dict(字典对象) 含义:模式内容的字典
desc: 数据格式:str(字符串) 含义:架构描述,默认值:None
**功能:**添加模式(字典的形式)来描述需要写入的原始数据的形式
**返回值:**一个int整型数,表示该schema对于的id
异常处理(raise):
- MRMInvalidSchemaError:如果所写的schema无效,可能时格式错误
- MRMBuildSchemaError:如果未能生成schema
- MRMAddSchemaError:如果未能田间schema
源码:
def add_schema(self, content, desc=None):
# 调用self._validate_schema判断schema是否合理,ret为Ture or False,error_msg
# 为错误信息
ret, error_msg = self._validate_schema(content)
# 若ret为False,即schema不合理,raise异常处理——1
if ret is False:
raise MRMInvalidSchemaError(error_msg)
# 调用self._header.build_schema以一个合理的原始的schema生成具体的schema对象,
# 此处的schema变量是ShardSchema类的对象
schema = self._header.build_schema(content, desc)
# 调用ShardHeader类的add_schema方法,为ShareHeader添加一个schema对象,并返回它的id
return self._header.add_schema(schema)
# 参数:content:数据格式:dic(字典对象) 含义:schema的原始格式
# 功能:判断schema是否合理,若不合理的话,会收集错误信息
# 返回值:1. bool类型,schema是否合理 2.str(字符串类型),若schema不合理,则它会携带错误信息
def _validate_schema(self, content):
"""
Validate schema and return validation result and error message.
Args:
content (dict): Dict of raw schema.
Returns:
bool, whether the schema is valid.
str, error message.
"""
error = ''
# 判断content是否为空
if not content:
error = 'Schema content is empty.'
return False, error
# 判断content是否为字典对象
if isinstance(content, dict) is False:
error = 'Schema content should be dict.'
return False, error
# 判断content是否按照schema的格式规则,以及是否符合字典的构造规则
for k, v in content.items():
if not re.match(r'^[0-9a-zA-Z\_]+$', k):
error = "Field '{}' should be composed of " \
"'0-9' or 'a-z' or 'A-Z' or '_'.".format(k)
return False, error
if v and isinstance(v, dict):
if len(v) == 1 and 'type' in v:
if v['type'] not in VALID_ATTRIBUTES:
error = "Field '{}' contain illegal " \
"attribute '{}'.".format(k, v['type'])
return False, error
elif len(v) == 2 and 'type' in v:
res_1, res_2 = self._validate_array(k, v)
if res_1 is not True:
return res_1, res_2
else:
error = "Field '{}' contains illegal attributes.".format(v)
return False, error
else:
error = "Field '{}' should be dict.".format(k)
return False, error
return True, error
# 参数:content:数据格式:dic(字典对象) 含义:schema的原始格式,用户所定义的格式
# desc:数据格式:str(字符串) 含义:描述schema的大概所用
# 功能:建立一个原始的schema(一个dict对象),用户按照schema生成具体的schema对象
# 返回值:ShardSchema类
def build_schema(self, content, desc=None):
"""
Build raw schema to generate schema object.
Args:
content (dict): Dict of user defined schema.
desc (str,optional): String of schema description.
Returns:
Class ShardSchema.
Raises:
MRMBuildSchemaError: If failed to build schema.
"""
desc = desc if desc else ""
schema = ms.Schema.build(desc, content)
if not schema:
logger.error("Failed to add build schema.")
raise MRMBuildSchemaError
return schema
- **类方法:**def add_index_fields(self, index_fields):
参数: index_fields:数据格式: list[str],字符串类型的列表 含义:schema的索引字段
**功能:**可以从schema中选择索引字段以加速读取
**返回值:**MSRStatus, SUCCESS or FAILED.
异常处理(raise):
- ParamTypeError:如果索引字段无效
- MRMDefineIndexError:如果索引字段不是基元类型
- MRMAddIndexError–如果添加索引字段失败
- MRMGetMetaError–如果未设置架构或未能获取元数据
源码:
def add_index(self, index_fields):
# 判断是否符合变量名规范,判断是否为列表
if not index_fields or not isinstance(index_fields, list):
raise ParamTypeError('index_fields', 'list')
# 判断是否为基元类型
for field in index_fields:
if field in self._header.blob_fields:
raise MRMDefineIndexError("Failed to set field {} since it's not primitive type.".format(field))
if not isinstance(field, str):
raise ParamTypeError('index field', 'str')
# 调用调用ShardHeader类的add_index_fields方法,返回MSRStatus, SUCCESS or FAILED
# 若为FAILED,会打印错误信息:“Failed to add index field.”
return self._header.add_index_fields(index_fields)
- **类方法:**def write_raw_data(self, raw_data, parallel_writer=False):
**参数:**raw_data:数据格式:list[dict] 含义:原始数据列表
parallel_writer:数据格式:bool(默认值为False) 含义:是否并行写入数据
**功能:**根据原始数据raw_data与schema的相应格式验证后,将原始数据转化为一系列连续的MindRecord文件
返回值: SUCCESS or FAILED
源码:
def write_raw_data(self, raw_data, parallel_writer=False):
if not self._writer.is_open:
self._writer.open(self._paths)
if not self._writer.get_shard_header():
self._writer.set_shard_header(self._header)
if not isinstance(raw_data, list):
raise ParamTypeError('raw_data', 'list')
for each_raw in raw_data:
if not isinstance(each_raw, dict):
raise ParamTypeError('raw_data item', 'dict')
self._verify_based_on_schema(raw_data)
return self._writer.write_raw_data(raw_data, True, parallel_writer)
# 参数:raw_data:数据格式:list[dict] 含义:原始数据列表
# 功能:根据schema验证每一行数据,若验证失败,则删除无效数据
# 允许的数据类型包括:“int32”、“int64”、“float32”、“float64”、“string”、“bytes”
def _verify_based_on_schema(self, raw_data):
error_data_dic = {}
schema_content = self._header.schema
for field in schema_content:
for i, v in enumerate(raw_data):
if i in error_data_dic:
continue
if field not in v:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"there is not '{}' object in the raw data.".format(i, field)
continue
field_type = type(v[field]).__name__
if field_type not in VALUE_TYPE_MAP:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
continue
if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
continue
if field_type == 'ndarray':
if 'shape' not in schema_content[field]:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
else:
try:
np.reshape(v[field], schema_content[field]['shape'])
except ValueError:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
error_data_dic = sorted(error_data_dic.items(), reverse=True)
for i, v in error_data_dic:
raw_data.pop(i)
logger.warning(v)
实例
代码:
from mindspore.mindrecord import FileWriter
# 定义一个schema,为双重字典的嵌套,file_name对{"type": "string"},label对{"type": "int32"}
# data对{"type": "bytes"},我们需要将写入的数据data按照schema的格式给出
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
# data为要写入MindRecord的数据,可以将其与cv_schema_json比较验证它的格式,图片以二进制流给出
data = [{"file_name": "1.jpg", "label": 0, "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff\xd9"},
{"file_name": "2.jpg", "label": 56, "data": b"\xe6\xda\xd1\xae\x07\xb8>\xd4\x00\xf8\x129\x15\xd9\xf2q\xc0\xa2\x91YFUO\x1dsE1\x1ep"},
{"file_name": "3.jpg", "label": 99, "data": b"\xaf\xafU<\xb8|6\xbd}\xc1\x99[\xeaj+\x8f\x84\xd3\xcc\xa0,i\xbb\xb9-\xcdz\xecp{T\xb1\xdb"}]
# 添加索引
indexes = ["file_name", "label"]
# 创建4个MindRecod文件,又含有索引,所以最后又8个MindRecord文件,分别为test.mindrecord0、
# test.mindrecord0.db、test.mindrecord1、test.mindrecord1.db、test.mindrecord2、
# test.mindrecord2.db、test.mindrecord3、test.mindrecord3.db
# 它被称为MindSpore数据集,test.mindrecord0为数据文件,test.mindrecord0.db为索引文件
writer = FileWriter(file_name="test.mindrecord", shard_num=4)
writer.add_schema(cv_schema_json, "test_schema")
writer.add_index(indexes)
writer.write_raw_data(data)
# 将最终内存写入磁盘
writer.commit()
运行截图:
可以看到运行结果,为我们在FileWriter类方法介绍中提到的MSRStatus SUCCESS,表示写入数据写入磁盘成功,那我们来看一下它在磁盘中的显示。
源码地址:https://gitee.com/mindspore/mindspore/tree/master/mindspore/mindrecord
好了到这里本篇文章就结束了,感谢大家的阅读。下篇文章我们再见。