预备知识
预备知识,需要了解python 可迭代对象 和 迭代器 。
可迭代对象:是指能够一次返回其中一个成员的对象,通常使用for循环 来完成此操作,如字符串、列表、元组、集合、字典等等之类的对象都属于可迭代对象。实现__getitem__(self, idx)函数的class就是一个可迭代对象。
MindSpore实现自定义数据集加载需要两个步骤:
假设当前需要读取一个CSV文件,格式如下:第一列为文本,第二列为标签,两列的标题分别为review,label。
步骤1:定义一个可迭代对象
实现__getitem__(self, idx)函数的class就是一个可迭代对象,下面对象通过_load(self)函数读取数据集,再通过__getitem__(self, idx)函数返回数据集。
class MyCSVData():
"""
文本CSV数据集加载器
加载数据集并处理为一个Python迭代对象。
"""
def __init__(self, path):
self.path = path
self.review, self.label = [], []
self._load()
def _load(self):
# 根据self.path load 需要读的csv文件
with open(self.path, "r") as csv_file:
dict_reader = csv.DictReader(csv_file)
# 按行读取
for row in dict_reader:
review = row['review']
label = int(np.float32(row['label']))
# 数据处理
label_onehot = [0] * 5
label_onehot[label - 1] = 1
review = re.split(' |,', review.lower())
# 数据加载到List
self.review.append(review)
self.label.append(label_onehot)
def __getitem__(self, idx):
"""
定义可迭代对象返回当前结果的逻辑
"""
return self.review[idx], self.label[idx]
def __len__(self):
"""
返回可迭代对象的长度
:return: int
"""
return len(self.review)
步骤2: 加载至GeneratorDataset(用于生成MindSpore数据对象)
GeneratorDataset会返回一个Dataset对象,该对象在后续训练中可以用于MindSpore训练及预测等操作,如下代码表示基于MyCSVData生成一个Dataset对象。
import mindspore.dataset as dataset
csv_path = "./dataset/test1.csv"
data_train = dataset.GeneratorDataset(MyCSVData(csv_path),
column_names=["review", "label"])
通过以上两个步骤,就可以将任意格式的数据集读取为MindSpore的Dataset对象。
打印数据
data = next(data_train.create_dict_iterator())
print(data["review"])
print(data["label"])
分割验证集
data_train, data_valid = data_train.split([0.7, 0.3])
设置Batch Size
BATCH_SIZE = 128
data_train = data_train.batch(BATCH_SIZE, drop_remainder=True)
data_valid = data_valid.batch(BATCH_SIZE, drop_remainder=True)
训练
经过如上处理的数据可以直接基于MindSpore进行训练
model.train(num_epochs,
data_train,
callbacks=[
ValAccMonitor(model, data_valid, num_epochs, ckpt_directory='./model/')
])