MindSpore Yidiantong (易点通) · Deep Dive Series: Dataset Loading with TFRecordDataset

Dive Into MindSpore – TFRecordDataset For Dataset Load

Development environment

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

Contents

  • Background
  • Reading the documentation first
  • Generating TFRecord files
  • Loading the data
  • Summary
  • References

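1. Background

TFRecord is a data format designed by the TensorFlow team.

TFRecord is a simple format for storing a sequence of binary records, and it makes better use of memory. A file contains multiple tf.train.Example messages. Each Example message holds a set of tf.train.Feature entries, and each feature is a key-value pair whose key is a string and whose value is one of three types:

  • bytes_list: stores string and byte data
  • float_list: stores float (float32) and double (float64) data
  • int64_list: stores bool, enum, int32, uint32, int64, and uint64 data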
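
To make the "sequence of binary records" idea concrete, here is a minimal pure-Python sketch of the on-disk framing that TensorFlow documents for TFRecord files: each record is an 8-byte little-endian length, a 4-byte masked CRC-32C of the length, the payload, and a 4-byte masked CRC-32C of the payload. The helper names (crc32c, masked_crc, write_record, read_records) are illustrative only and are not part of TensorFlow or MindSpore. In a real file each payload would be a serialized tf.train.Example; plain byte strings are used here to keep the sketch dependency-free.

```python
import io
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli); slow but needs no third-party package."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload: bytes) -> None:
    """Append one framed record to a binary stream."""
    length = struct.pack("<Q", len(payload))
    stream.write(length)
    stream.write(struct.pack("<I", masked_crc(length)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))


def read_records(stream):
    """Yield payloads one by one, verifying both checksums."""
    while True:
        header = stream.read(12)
        if len(header) < 12:
            return  # end of stream
        length_bytes = header[:8]
        (length_crc,) = struct.unpack("<I", header[8:])
        if masked_crc(length_bytes) != length_crc:
            raise ValueError("corrupted length field")
        (length,) = struct.unpack("<Q", length_bytes)
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if masked_crc(payload) != payload_crc:
            raise ValueError("corrupted payload")
        yield payload


# Round trip through an in-memory buffer instead of a real file.
buffer = io.BytesIO()
for payload in (b"first record", b"second record"):
    write_record(buffer, payload)
buffer.seek(0)
records = list(read_records(buffer))
print(records)
```

The point of the sketch is only that a TFRecord file is a flat sequence of length-prefixed, checksummed blobs; the meaning of each blob (the features and their types) lives entirely in the serialized Example inside it.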

With that brief introduction to TFRecord out of the way, let's turn to the main topic: MindSpore's support for the TFRecord format.

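2. Reading the documentation first

As usual, let's start with the official description of the API.

[Figure: screenshot of the TFRecordDataset API documentation]

The main parameters are briefly described below:

  • dataset_files – path(s) of the dataset files.
  • schema – the read schema, that is, the format of the data stored in the TFRecord files to be read. It can be passed as a JSON file or as a Schema object. The default is None (not specified).
  • columns_list – the data columns to read. By default all columns are read.
  • num_samples – the number of samples to read.
  • shuffle – whether to shuffle the data; see the earlier article in this series for details.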

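3. Generating TFRecord files

This article uses the THUCNews dataset. If you need to use the dataset for commercial purposes, please contact its authors.

The dataset can be downloaded from the OpenI (启智) community.

Because the following sections load a TFRecord dataset, this section generates one first. Readers unfamiliar with TensorFlow can simply copy the code as is.

The code for generating the TFRecord files is as follows: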

import codecs
import os
import re
import six
import tensorflow as tf

from collections import Counter


def _int64_feature(values):
    """Returns a TF-Feature of int64s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float32_feature(values):
    """Returns a TF-Feature of float32s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(values):
    """Returns a TF-Feature of bytes.
    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def convert_to_feature(values):
    """Convert to TF-Feature based on the type of element in values.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]

    if isinstance(values[0], int):
        return _int64_feature(values)
    elif isinstance(values[0], float):
        return _float32_feature(values)
    elif isinstance(values[0], bytes):
        return _bytes_feature(values)
    else:
        raise ValueError("feature type {0} is not supported now !".format(type(values[0])))


def dict_to_example(dictionary):
    """Converts a dictionary of string->int to a tf.Example."""
    features = {}
    for k, v in six.iteritems(dictionary):
        features[k] = convert_to_feature(values=v)
    return tf.train.Example(features=tf.train.Features(feature=features))


def get_txt_files(data_dir):
    cls_txt_dict = {}
    txt_file_list = []

    # get files list and class files list.
    sub_data_name_list = next(os.walk(data_dir))[1]
    sub_data_name_list = sorted(sub_data_name_list)
    for sub_data_name in sub_data_name_list:
        sub_data_dir = os.path.join(data_dir, sub_data_name)
        data_name_list = next(os.walk(sub_data_dir))[2]
        data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
        cls_txt_dict[sub_data_name] = data_file_list
        txt_file_list.extend(data_file_list)
        num_data_files = len(data_file_list)
        print("{}: {}".format(sub_data_name, num_data_files), flush=True)
    num_txt_files = len(txt_file_list)
    print("total: {}".format(num_txt_files), flush=True)

    return cls_txt_dict, txt_file_list


def get_txt_data(txt_file):
    with codecs.open(txt_file, "r", "UTF8") as fp:
        txt_content = fp.read()
    # collapse every run of whitespace into a single space
    txt_data = re.sub(r"\s+", " ", txt_content)

    return txt_data


def build_vocab(txt_file_list, vocab_size=7000):
    counter = Counter()
    for txt_file in txt_file_list:
        txt_data = get_txt_data(txt_file)
        counter.update(txt_data)

    num_vocab = len(counter)
    if num_vocab < vocab_size - 1:
        real_vocab_size = num_vocab + 2
    else:
        real_vocab_size = vocab_size

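    # pad_id is 0, unk_id is 1
    # NOTE: ids are assigned from 1 upward, so the most frequent character shares
    # id 1 with unk_id; use `ix + 2` if the two must stay distinct.
    vocab_dict = {word_freq[0]: ix + 1 for ix, word_freq in enumerate(counter.most_common(real_vocab_size - 2))}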

    print("real vocab size: {}".format(real_vocab_size), flush=True)
    print("vocab dict:\n{}".format(vocab_dict), flush=True)

    return vocab_dict


def make_tfrecords(
        data_dir, tfrecord_dir, vocab_size=7000, min_seq_length=10, max_seq_length=800,
        num_train=8, num_test=2, start_fid=0):
    # get txt files
    cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
    # map word to id
    vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size=vocab_size)
    # map class to id
    class_dict = {class_name: ix for ix, class_name in enumerate(cls_txt_dict.keys())}

    train_writers = []
    for fid in range(start_fid, num_train+start_fid):
        tfrecord_file = os.path.join(tfrecord_dir, "train_{:04d}.tfrecord".format(fid))
        writer = tf.io.TFRecordWriter(tfrecord_file)
        train_writers.append(writer)

    test_writers = []
    for fid in range(start_fid, num_test+start_fid):
        tfrecord_file = os.path.join(tfrecord_dir, "test_{:04d}.tfrecord".format(fid))
        writer = tf.io.TFRecordWriter(tfrecord_file)
        test_writers.append(writer)

    pad_id = 0
    unk_id = 1
    num_samples = 0
    num_train_samples = 0
    num_test_samples = 0
    for class_name, class_file_list in cls_txt_dict.items():
        class_id = class_dict[class_name]
        num_class_pass = 0
        for txt_file in class_file_list:
            txt_data = get_txt_data(txt_file=txt_file)
            txt_len = len(txt_data)
            if txt_len < min_seq_length:
                num_class_pass += 1
                continue
            if txt_len > max_seq_length:
                txt_data = txt_data[:max_seq_length]
                txt_len = max_seq_length
            word_ids = []
            for word in txt_data:
                word_id = vocab_dict.get(word, unk_id)
                word_ids.append(word_id)
            for _ in range(max_seq_length - txt_len):
                word_ids.append(pad_id)

            example = dict_to_example({"input": word_ids, "class": class_id})
            num_samples += 1
            if num_samples % 10 == 0:
                num_test_samples += 1
                writer_id = num_test_samples % num_test
                test_writers[writer_id].write(example.SerializeToString())
            else:
                num_train_samples += 1
                writer_id = num_train_samples % num_train
                train_writers[writer_id].write(example.SerializeToString())
        print("{} pass: {}".format(class_name, num_class_pass), flush=True)

    for writer in train_writers:
        writer.close()
    for writer in test_writers:
        writer.close()

    print("num samples: {}".format(num_samples), flush=True)
    print("num train samples: {}".format(num_train_samples), flush=True)
    print("num test samples: {}".format(num_test_samples), flush=True)


def main():
    data_dir = "{your_data_dir}"
    tfrecord_dir = "{your_tfrecord_dir}"
    make_tfrecords(data_dir=data_dir, tfrecord_dir=tfrecord_dir)


if __name__ == "__main__":
    main()

Save the code above to a file named make_tfrecord.py and run the following command.

Note: replace data_dir and tfrecord_dir with your own directories.

python3 make_tfrecord.py
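
The character-level encoding performed by build_vocab and make_tfrecords can be tried on in-memory strings before running the full script. The sketch below is a simplified stand-in rather than the script itself: the helper names build_char_vocab and encode are hypothetical, and it assigns ids starting at 2 so that pad_id 0 and unk_id 1 stay reserved, whereas the script above starts numbering at 1.

```python
from collections import Counter

PAD_ID = 0  # used to fill sequences up to max_seq_length
UNK_ID = 1  # used for characters that are not in the vocabulary


def build_char_vocab(texts, vocab_size=7000):
    """Map the most frequent characters to ids 2, 3, 4, ..."""
    counter = Counter()
    for text in texts:
        counter.update(text)  # counts individual characters
    most_common = counter.most_common(vocab_size - 2)
    return {char: ix + 2 for ix, (char, _) in enumerate(most_common)}


def encode(text, vocab, max_seq_length=800):
    """Truncate to max_seq_length, map characters to ids, then pad."""
    ids = [vocab.get(char, UNK_ID) for char in text[:max_seq_length]]
    ids.extend([PAD_ID] * (max_seq_length - len(ids)))
    return ids


vocab = build_char_vocab(["aab", "abc"], vocab_size=4)
encoded = encode("abz", vocab, max_seq_length=5)
print(vocab)    # {'a': 2, 'b': 3}
print(encoded)  # [2, 3, 1, 0, 0]
```

With vocab_size=4 only two real characters fit next to the two reserved ids, so "z" falls back to UNK_ID and the tail of the sequence is filled with PAD_ID.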

Use the tree command to inspect the generated TFRecord directory. The output is as follows:

.
├── test_0000.tfrecord
├── test_0001.tfrecord
├── train_0000.tfrecord
├── train_0001.tfrecord
├── train_0002.tfrecord
├── train_0003.tfrecord
├── train_0004.tfrecord
├── train_0005.tfrecord
├── train_0006.tfrecord
└── train_0007.tfrecord

0 directories, 10 files
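
The eight train shards and two test shards listed above come from the split logic in make_tfrecords: every tenth kept sample goes to a test shard, all others go to a train shard, and the shard is chosen by a running counter modulo the shard count. The simulation below replays that assignment without any file I/O; assign_shards is a hypothetical helper written for this illustration.

```python
from collections import Counter


def assign_shards(total_samples, num_train=8, num_test=2):
    """Replay the split logic and return a list of (split, shard_index) pairs."""
    assignments = []
    num_train_samples = 0
    num_test_samples = 0
    for num_samples in range(1, total_samples + 1):
        if num_samples % 10 == 0:
            # every tenth sample is held out for testing
            num_test_samples += 1
            assignments.append(("test", num_test_samples % num_test))
        else:
            num_train_samples += 1
            assignments.append(("train", num_train_samples % num_train))
    return assignments


counts = Counter(assign_shards(100))
for key in sorted(counts):
    print(key, counts[key])
```

For 100 samples this gives a 90/10 split, with the test samples spread evenly over the two test shards and the train samples spread over the eight train shards to within one sample of each other.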

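4. Loading the data

With the TFRecord dataset from Section 3 in hand, this section shows how to use it in MindSpore.

4.1 Using schema

4.1.1 Without a schema

First, let's check whether the data can be read correctly when the schema parameter is left unspecified, that is, at its default value.

The code is as follows: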

import os

from mindspore.common import dtype as mstype
from mindspore.dataset import Schema
from mindspore.dataset import TFRecordDataset


def get_tfrecord_files(tfrecord_dir, file_suffix="tfrecord", is_train=True):
    if not os.path.exists(tfrecord_dir):
        raise ValueError("tfrecord directory: {} not exists!".format(tfrecord_dir))

    if is_train:
        file_prefix = "train"
    else:
        file_prefix = "test"

    data_sources = []
    for parent, _, filenames in os.walk(tfrecord_dir):
        for filename in filenames:
            if not filename.startswith(file_prefix):
                continue
            tmp_path = os.path.join(parent, filename)
            if tmp_path.endswith(file_suffix):
                data_sources.append(tmp_path)
    return data_sources


def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    dataset = TFRecordDataset(dataset_files=tfrecord_files, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break


def main():
    tfrecord_dir = "{your_tfrecord_dir}"
    tfrecord_json = "{your_tfrecord_json_file}"
    load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=None)


if __name__ == "__main__":
    main()

Code notes:

  • get_tfrecord_files – collects the list of TFRecord files for the requested split (train or test)
  • load_tfrecord – loads the dataset
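
os.walk yields file names in a filesystem-dependent order, so with shuffle=False the first sample read may differ between machines. As an optional variation, assuming a deterministic file order is wanted, the same filtering can be written with pathlib and sorted. The function name list_tfrecord_files is hypothetical, and unlike get_tfrecord_files it raises FileNotFoundError rather than ValueError when the directory is missing. The temporary directory and empty files are only there to demonstrate the filtering.

```python
import tempfile
from pathlib import Path


def list_tfrecord_files(tfrecord_dir, file_suffix="tfrecord", is_train=True):
    """Return matching TFRecord paths under tfrecord_dir, sorted by path."""
    root = Path(tfrecord_dir)
    if not root.is_dir():
        raise FileNotFoundError("tfrecord directory not found: {}".format(root))
    file_prefix = "train" if is_train else "test"
    # file name must start with the prefix and end with the suffix
    pattern = "{}*{}".format(file_prefix, file_suffix)
    return sorted(str(path) for path in root.rglob(pattern) if path.is_file())


with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("train_0001.tfrecord", "train_0000.tfrecord",
                 "test_0000.tfrecord", "notes.txt"):
        Path(demo_dir, name).touch()
    train_files = [Path(p).name for p in list_tfrecord_files(demo_dir)]
    test_files = [Path(p).name for p in list_tfrecord_files(demo_dir, is_train=False)]

print(train_files)
print(test_files)
```

The returned list can be passed to dataset_files in the same way as the result of get_tfrecord_files.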

Save the code above to a file named load_tfrecord_dataset.py and run the following command:

python3 load_tfrecord_dataset.py

The output is shown below.

The data saved in the TFRecord files is parsed correctly, including the data types and shapes.

{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541,
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298,
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424,
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304,
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35,
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428,
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197,
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208,
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211,
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197,
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26,
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982,
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83,
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136,
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9,
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301,
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12,
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9,
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25,
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6,
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167,
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752,
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39,
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168,
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4,
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83,
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736,
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13,
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304,
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000,
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901,
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9,
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27,
  135,  979,    1,   35,  166,  181,   90,  143])}
4.1.2 Using a Schema object

This section shows how to use mindspore.dataset.Schema to specify the read schema.

Modify load_tfrecord as follows:

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    data_schema = Schema()
    data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
    data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

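Code notes:

  • A Schema object is used here, specifying each column's name, data type, and shape.

Save the file and run load_tfrecord_dataset.py again. The output is shown below.

The data saved in the TFRecord files is parsed correctly, including the data types and shapes.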

{'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541, 
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298, 
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424, 
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304, 
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35, 
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428, 
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197, 
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208, 
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211, 
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197, 
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26, 
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982, 
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83, 
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136, 
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9, 
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301, 
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12, 
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9, 
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25, 
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6, 
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167, 
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752, 
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39, 
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168, 
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4, 
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83, 
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736, 
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13, 
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304, 
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000, 
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901, 
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9, 
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27, 
  135,  979,    1,   35,  166,  181,   90,  143]), 'class': Tensor(shape=[1], dtype=Int64, value= [0])}
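4.1.3 Using a JSON file

This section shows how to use a JSON file to specify the read schema.

Create a file named tfrecord_sample.json with the following content:

numRows – the number of rows (samples) to read, not the number of columns. According to the TFRecordDataset documentation for num_samples, when num_samples is None and numRows is greater than 0, numRows rows are read, so the value 2 below limits loading to two samples.

columns – for each column, in order: the column name, data type, rank (number of dimensions), and shape.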

{
  "datasetType": "TF",
  "numRows": 2,
  "columns": {
    "input": {
      "type": "int64",
      "rank": 1,
      "shape": [800]
    },
    "class" : {
      "type": "int64",
      "rank": 1,
      "shape": [1]
    }
  }
}
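
Instead of writing the file by hand, the same JSON can be generated with the standard json module, which keeps rank consistent with shape. This is only a sketch: write_schema_json is a hypothetical helper, the field names simply mirror the file above, and a temporary directory is used so that the example does not touch the working directory.

```python
import json
import os
import tempfile


def write_schema_json(path, columns, num_rows):
    """Write a schema file; columns maps column name -> (type string, shape list)."""
    schema = {
        "datasetType": "TF",
        "numRows": num_rows,
        "columns": {
            # rank is derived from the shape so the two cannot drift apart
            name: {"type": dtype, "rank": len(shape), "shape": list(shape)}
            for name, (dtype, shape) in columns.items()
        },
    }
    with open(path, "w", encoding="utf-8") as fp:
        json.dump(schema, fp, indent=2)
    return schema


with tempfile.TemporaryDirectory() as demo_dir:
    json_path = os.path.join(demo_dir, "tfrecord_sample.json")
    write_schema_json(
        json_path,
        {"input": ("int64", [800]), "class": ("int64", [1])},
        num_rows=2,
    )
    with open(json_path, encoding="utf-8") as fp:
        loaded = json.load(fp)

print(json.dumps(loaded, indent=2))
```

Reading the file back with json.load is a cheap way to confirm that it is valid JSON before handing its path to TFRecordDataset.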

With the JSON file in place, the next step is to use it to read the data.

Modify load_tfrecord as follows:

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=tfrecord_json, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

Also change the call in main as follows:

load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=tfrecord_json)

Code notes:

  • Here the schema parameter is set directly to the path of the JSON file.

Save the file and run load_tfrecord_dataset.py again. The output is as follows:

{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541, 
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298, 
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424, 
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304, 
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35, 
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428, 
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197, 
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208, 
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211, 
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197, 
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26, 
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982, 
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83, 
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136, 
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9, 
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301, 
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12, 
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9, 
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25, 
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6, 
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167, 
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752, 
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39, 
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168, 
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4, 
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83, 
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736, 
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13, 
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304, 
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000, 
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901, 
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9, 
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27, 
  135,  979,    1,   35,  166,  181,   90,  143])}

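4.2 Using columns_list

In some scenarios only one or a few columns are needed rather than all of the data. In that case, columns_list can be specified when loading the data.

As a simple demonstration, the following reads only the class column.

Building on 4.1.2, modify load_tfrecord as follows: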

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    data_schema = Schema()
    data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
    data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, columns_list=["class"], shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

Save the file and run load_tfrecord_dataset.py again. The output is shown below.

Only the specified column is read, and its data is loaded correctly.

{'class': Tensor(shape=[1], dtype=Int64, value= [0])}

5. Summary

This article showed how to load a TFRecord dataset in MindSpore, focusing on the schema and columns_list parameters of TFRecordDataset.

6. References

This is an original article. All rights are reserved by the author; reproduction without permission is prohibited.
