
Dive Into MindSpore – TFRecordDataset For Dataset Load



  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0


  • 背景介绍
  • 先看文档
  • 生成TFRecord
  • 数据加载
  • 本文总结
  • 本文参考

1. 背景介绍


TFRecord 格式是一种用于存储二进制记录序列的简单格式,该格式能够更好地利用内存,内部包含多个tf.train.Example,在一个Examples消息体中包含一系列的tf.train.feature属性,而每一个feature是一个key-value的键值对,其中key是string类型,value的取值有三种:

  • bytes_list:可以存储string和byte两种数据类型
  • float_list:可以存储float(float32)和double(float64)两种数据类型
  • int64_list:可以存储bool, enum, int32, uint32, int64, uint64数据类型


2. 先看文档



  • dataset_files – 数据集文件路径。
  • schema – 读取模式策略,通俗来说就是要读取的tfrecord文件内的数据内容格式。可以通过json或者Schema传入。默认为None不指定。
  • columns_list – 指定读取的具体数据列。默认全部读取。
  • num_samples – 指定读取出来的样本数量。
  • shuffle – 是否对数据进行打乱,可参考之前的文章解读。

3. 生成TFRecord





import codecs
import os
import re
import six
import tensorflow as tf

from collections import Counter

def _int64_feature(values):
    """Returns a TF-Feature of int64s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    # Int64List expects a sequence, so wrap a lone scalar in a list.
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _float32_feature(values):
    """Returns a TF-Feature of float32s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    # FloatList expects a sequence, so wrap a lone scalar in a list.
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def _bytes_feature(values):
    """Returns a TF-Feature of bytes.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    # BytesList expects a sequence, so wrap a lone scalar in a list.
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

def convert_to_feature(values):
    """Convert to TF-Feature based on the type of element in values.

    The feature kind is chosen from the type of the first element only, so
    all elements are expected to share one type.

    Args:
        values: A scalar or a non-empty list/tuple of int, float or bytes.

    Returns:
        A TF-Feature.

    Raises:
        ValueError: If the first element is not an int, float or bytes.
        IndexError: If values is an empty list or tuple.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]

    first = values[0]
    if isinstance(first, int):
        return _int64_feature(values)
    if isinstance(first, float):
        return _float32_feature(values)
    if isinstance(first, bytes):
        return _bytes_feature(values)
    # Reached only for unsupported element types (e.g. str).
    raise ValueError("feature type {0} is not supported now !".format(type(first)))

def dict_to_example(dictionary):
    """Build a tf.train.Example with one feature per entry of ``dictionary``.

    Each value is turned into a TF-Feature by ``convert_to_feature`` and
    stored under its original key.
    """
    feature_map = {
        name: convert_to_feature(values=raw)
        for name, raw in six.iteritems(dictionary)
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def get_txt_files(data_dir):
    """Collect the data files under each class sub-directory of ``data_dir``.

    Every immediate sub-directory of ``data_dir`` is treated as one class and
    every file directly inside it as one sample of that class.

    Args:
        data_dir: Root directory containing one sub-directory per class.

    Returns:
        A tuple ``(cls_txt_dict, txt_file_list)`` where ``cls_txt_dict`` maps
        each class name (in sorted order) to the list of its file paths and
        ``txt_file_list`` is the flat list of all file paths.
    """
    cls_txt_dict = {}
    txt_file_list = []

    # The first os.walk() entry lists only the direct children of a directory.
    sub_data_name_list = sorted(next(os.walk(data_dir))[1])
    for sub_data_name in sub_data_name_list:
        sub_data_dir = os.path.join(data_dir, sub_data_name)
        data_name_list = next(os.walk(sub_data_dir))[2]
        data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
        cls_txt_dict[sub_data_name] = data_file_list
        # Without this the flat list stays empty and no vocabulary is built.
        txt_file_list.extend(data_file_list)
        print("{}: {}".format(sub_data_name, len(data_file_list)), flush=True)
    print("total: {}".format(len(txt_file_list)), flush=True)

    return cls_txt_dict, txt_file_list

def get_txt_data(txt_file):
    """Read a UTF-8 text file and collapse every whitespace run to one space.

    Args:
        txt_file: Path of the text file to read.

    Returns:
        The file content as a single string with each run of whitespace
        (including newlines) replaced by one space character.
    """
    with open(txt_file, "r", encoding="utf-8") as fp:
        txt_content = fp.read()
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    return re.sub(r"\s+", " ", txt_content)

def build_vocab(txt_file_list, vocab_size=7000):
    """Build a character-to-id vocabulary from the given text files.

    Characters are counted over all files and the most frequent ones receive
    ids starting at 2, in descending frequency order. Ids 0 and 1 are
    reserved for padding and unknown characters and are not stored in the
    returned dict.

    Args:
        txt_file_list: Paths of the text files to count characters in.
        vocab_size: Upper bound of the vocabulary size, including the two
            reserved ids.

    Returns:
        A dict mapping each kept character to its integer id (>= 2).
    """
    counter = Counter()
    for txt_file in txt_file_list:
        # Iterating a str yields characters, so this counts per character.
        counter.update(get_txt_data(txt_file))

    # Two slots are reserved for pad/unk; never exceed the requested size.
    real_vocab_size = min(len(counter) + 2, vocab_size)

    # pad_id is 0, unk_id is 1, so real tokens must start at id 2.
    vocab_dict = {
        word: ix + 2
        for ix, (word, _) in enumerate(counter.most_common(real_vocab_size - 2))
    }

    print("real vocab size: {}".format(real_vocab_size), flush=True)
    print("vocab dict:\n{}".format(vocab_dict), flush=True)

    return vocab_dict

def make_tfrecords(
        data_dir, tfrecord_dir, vocab_size=7000, min_seq_length=10, max_seq_length=800,
        num_train=8, num_test=2, start_fid=0):
    """Convert a directory of per-class text files into TFRecord shards.

    Each text is encoded character by character into vocabulary ids,
    truncated or padded to ``max_seq_length`` and stored as an example with
    an ``"input"`` id list and a ``"class"`` id. Every tenth kept sample goes
    to a test shard and the rest to train shards; samples are spread over
    the shards of their split in round-robin order.

    Args:
        data_dir: Root directory with one sub-directory per class.
        tfrecord_dir: Directory in which the shard files are written.
        vocab_size: Maximum vocabulary size passed to ``build_vocab``.
        min_seq_length: Texts shorter than this are skipped.
        max_seq_length: Fixed length every id sequence is cut or padded to.
        num_train: Number of train shard files.
        num_test: Number of test shard files.
        start_fid: Index used in the name of the first shard file.
    """
    # get txt files
    cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
    # map word to id
    vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size=vocab_size)
    # map class to id
    class_dict = {class_name: ix for ix, class_name in enumerate(cls_txt_dict)}

    pad_id = 0
    unk_id = 1
    num_samples = 0
    num_train_samples = 0
    num_test_samples = 0

    train_writers = []
    test_writers = []
    try:
        for fid in range(start_fid, num_train + start_fid):
            tfrecord_file = os.path.join(tfrecord_dir, "train_{:04d}.tfrecord".format(fid))
            train_writers.append(tf.io.TFRecordWriter(tfrecord_file))

        for fid in range(start_fid, num_test + start_fid):
            tfrecord_file = os.path.join(tfrecord_dir, "test_{:04d}.tfrecord".format(fid))
            test_writers.append(tf.io.TFRecordWriter(tfrecord_file))

        for class_name, class_file_list in cls_txt_dict.items():
            class_id = class_dict[class_name]
            num_class_pass = 0
            for txt_file in class_file_list:
                txt_data = get_txt_data(txt_file=txt_file)
                txt_len = len(txt_data)
                if txt_len < min_seq_length:
                    # Too short to be useful: count it and skip the sample.
                    num_class_pass += 1
                    continue
                if txt_len > max_seq_length:
                    txt_data = txt_data[:max_seq_length]
                    txt_len = max_seq_length

                # Encode per character, then right-pad to the fixed length.
                word_ids = [vocab_dict.get(word, unk_id) for word in txt_data]
                word_ids.extend([pad_id] * (max_seq_length - txt_len))

                example = dict_to_example({"input": word_ids, "class": class_id})
                num_samples += 1
                if num_samples % 10 == 0:
                    num_test_samples += 1
                    writer_id = num_test_samples % num_test
                    test_writers[writer_id].write(example.SerializeToString())
                else:
                    num_train_samples += 1
                    writer_id = num_train_samples % num_train
                    train_writers[writer_id].write(example.SerializeToString())
            print("{} pass: {}".format(class_name, num_class_pass), flush=True)
    finally:
        # Close every writer that was opened, even if conversion failed midway.
        for writer in train_writers + test_writers:
            writer.close()

    print("num samples: {}".format(num_samples), flush=True)
    print("num train samples: {}".format(num_train_samples), flush=True)
    print("num test samples: {}".format(num_test_samples), flush=True)

def main():
    """Generate the TFRecord files; replace the placeholder paths before running."""
    make_tfrecords(
        data_dir="{your_data_dir}",
        tfrecord_dir="{your_tfrecord_dir}",
    )


if __name__ == "__main__":
    main()



python3 make_tfrecord.py


├── test_0000.tfrecord
├── test_0001.tfrecord
├── train_0000.tfrecord
├── train_0001.tfrecord
├── train_0002.tfrecord
├── train_0003.tfrecord
├── train_0004.tfrecord
├── train_0005.tfrecord
├── train_0006.tfrecord
└── train_0007.tfrecord

0 directories, 10 files

4. 数据加载


4.1 schema使用

4.1.1 不指定schema



import os

from mindspore.common import dtype as mstype
from mindspore.dataset import Schema
from mindspore.dataset import TFRecordDataset

def get_tfrecord_files(tfrecord_dir, file_suffix="tfrecord", is_train=True):
    """List the train or test TFRecord files found under ``tfrecord_dir``.

    The directory is walked recursively. A file is kept when its name starts
    with ``"train"`` (or ``"test"`` when ``is_train`` is False) and its path
    ends with ``file_suffix``.

    Args:
        tfrecord_dir: Directory to search.
        file_suffix: Required ending of the file path.
        is_train: Select train files when True, test files otherwise.

    Returns:
        A list of matching file paths, in ``os.walk`` order.

    Raises:
        ValueError: If ``tfrecord_dir`` does not exist.
    """
    if not os.path.exists(tfrecord_dir):
        raise ValueError("tfrecord directory: {} not exists!".format(tfrecord_dir))

    file_prefix = "train" if is_train else "test"

    data_sources = []
    for parent, _, filenames in os.walk(tfrecord_dir):
        for filename in filenames:
            if not filename.startswith(file_prefix):
                continue
            tmp_path = os.path.join(parent, filename)
            if tmp_path.endswith(file_suffix):
                data_sources.append(tmp_path)
    return data_sources

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Load the train TFRecord files without a schema and print every row.

    ``tfrecord_json`` is accepted for interface compatibility but is not
    used by this variant.
    """
    files = get_tfrecord_files(tfrecord_dir)
    dataset = TFRecordDataset(dataset_files=files, shuffle=False)

    for row in dataset.create_dict_iterator():
        print(row, flush=True)

def main():
    """Load and print the dataset; replace the placeholder paths before running."""
    tfrecord_dir = "{your_tfrecord_dir}"
    # Not used here: this example deliberately loads without a schema.
    tfrecord_json = "{your_tfrecord_json_file}"
    load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=None)


if __name__ == "__main__":
    main()


  • get_tfrecord_files – 获取指定的TFRecord文件列表
  • load_tfrecord – 数据集加载


python3 load_tfrecord_dataset.py



{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541,
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298,
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424,
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304,
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35,
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428,
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197,
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208,
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211,
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197,
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26,
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982,
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83,
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136,
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9,
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301,
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12,
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9,
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25,
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6,
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167,
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752,
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39,
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168,
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4,
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83,
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736,
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13,
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304,
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000,
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901,
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9,
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27,
  135,  979,    1,   35,  166,  181,   90,  143])}
4.1.2 使用Schema对象



def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Load the train TFRecord files with an explicit Schema and print every row.

    The schema declares two int64 columns: ``input`` with shape [800] and
    ``class`` with shape [1]. ``tfrecord_json`` is not used by this variant.
    """
    files = get_tfrecord_files(tfrecord_dir)

    schema = Schema()
    for column, shape in (("input", [800]), ("class", [1])):
        schema.add_column(name=column, de_type=mstype.int64, shape=shape)

    dataset = TFRecordDataset(dataset_files=files, schema=schema, shuffle=False)

    for row in dataset.create_dict_iterator():
        print(row, flush=True)


  • 这里使用了Schema对象,并且指定了列名,列的数据类型和数据维度。



{'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541, 
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298, 
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424, 
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304, 
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35, 
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428, 
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197, 
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208, 
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211, 
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197, 
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26, 
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982, 
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83, 
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136, 
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9, 
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301, 
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12, 
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9, 
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25, 
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6, 
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167, 
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752, 
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39, 
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168, 
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4, 
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83, 
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736, 
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13, 
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304, 
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000, 
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901, 
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9, 
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27, 
  135,  979,    1,   35,  166,  181,   90,  143]), 'class': Tensor(shape=[1], dtype=Int64, value= [0])}
4.1.3 使用JSON文件



numRows – 数据列数

columns – 依次为每列的列名、数据类型、数据维数、数据维度。

{
  "datasetType": "TF",
  "numRows": 2,
  "columns": {
    "input": {
      "type": "int64",
      "rank": 1,
      "shape": [800]
    },
    "class" : {
      "type": "int64",
      "rank": 1,
      "shape": [1]
    }
  }
}



def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Load the train TFRecord files and print every row.

    ``tfrecord_json`` is passed straight through as the dataset's ``schema``
    argument.
    """
    files = get_tfrecord_files(tfrecord_dir)
    dataset = TFRecordDataset(dataset_files=files, schema=tfrecord_json, shuffle=False)

    for row in dataset.create_dict_iterator():
        print(row, flush=True)


# Inside main(): pass the JSON schema path through instead of None.
load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=tfrecord_json)


  • 这里直接将schema参数指定为JSON的文件路径


{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,  260,    3,  171,   45,    7,   65,  136,  869,  211,  215,  443,  541, 
    3,   91, 1719,  636,    2,  424,  291,   16,   86,   31,   12,  211,  215,  443,  541,  999,  322,  128,  916,  102,  743,  136,  121,  298, 
  454,    2,  234,  225,    1,  136,  121,  298,  454,  100,   49,   22,  152,   70,  677,  806,   31, 1719,  636,  100,   25,  237,    2,  424, 
    1,   39,  100,   39,   71,  228,  385,  999,  837,    1,  171,   45,  136,  869,  211,  215,  443,  541,    1,   35,   68,   20,  149,  304, 
   31,   70,  677,    1,  106,    9,  308,  487,  869,  153,  597,    2,  523,  262,  184,  145,   57,   36,  158,   13,   69,   41,    1,   35, 
   68,  511,  402,  152,  469,   41,  617,  761,   50,   36,  144,  281,   26,  308,  487,  869,  153,  597,    4,   23,  208,   17,  121,  428, 
  646,   71,    8,   10,   47,   40,   87,   32,  413,  133,    9,  641,  159,   74,  144,  281,   26,  308,  487,  869,  153,  597,    2,  197, 
  447,    1,   91,  549,  202,  208,   17,  121,  558,  123,    2,  113,  203,    1,  419, 1024,  200,  154,   80,   16,  147,   64,  111,  208, 
  219,  136,   25,    6,  153,  597,    4,  160,  134,   16,    1,  167,  229, 1719,  636,  514,    2,    9,    7,   65,  321,  136,  869,  211, 
  215,  443,  541,    1,  514,    2,   69,   33,   13,   88,   80,   94,  294,    2,  308,  487,  869,  153,  597,    1,   39,   69,   33,  197, 
   57,  310,  335,   50,   94,  294,    2,  308,  487,  869,  153,  597,    1,  221,   13,   74,  337,   56,  499,  117,  836,  621,  488,   26, 
   94,  294,    1,   10,    5,    7,   10,   21,  973,  124,  492,   69,   33,  514,  218,  168,  117,    1,   82,  285,  148,  697,    2,  982, 
  298, 1535,  119,  743,  201, 1187,    4,    3,  136,  121,  298,  454,  103,  752,   31,   12, 1496,  762,  164,    2,  609,    6,  175,   83, 
  170,  257,  454,  963,    1,  149,   57,  136,  121,  298,  454,   62,   52,   87,  110,  257,   12,   34,   39,    2,  677, 1151,    1,  136, 
  121,  298,  454,  100,   49,   22,  138,   55,   39,    1,  752,  744,  184,   36,  169,   11,  561,    9,    1,   13,   74,   39,   62,    9, 
  308,  487,  869,  250,  321,    4,   23,  211,  215,  443,  541,    9,  641,   32,  900, 2586,   83, 1157,  165,  978,   97,  694,  837,  301, 
   22,   97,  694,  837,    9,  124,  492,    2,  720, 1341,    1,   35,   68,   32, 2294,  216,    2,    1,  106,    9,   13,    9,   20,   12, 
   25,   26,  973,  124,    1,   91,    9,   20,   12,  344,   36,    4,   23,    3,  167,  229,  211,  215,  443,  541,    2,  283,  683,    9, 
 1719,  636,    1,  292,  278,    9,  641,  103,   32,  283,  683,  976,  944,  511,  316,   30,  178,  223,  795,  136,  164,  301,   22,   25, 
  172,   26,   18, 1102,   69,   41,  136,  869,    1,   35,   68,  184,  344,  285,   74,    1,  178,  223,  795,  136,  164,   13,    9,    6, 
   26, 1152,  285,    1,  163,   20,   35,   68,  184,  344,  165,  894,   74,  521,   96,   39,    1,  976,  944,  511,  316,   62,    9,  167, 
  149,    1, 1024, 1405,  164,  271,  454,  102,  743,   62,    9,   25,  278,  100,    2,    4,   23,    3,  136,  121,  298,  454,  103,  752, 
   31,   12,  145,   57,  442,   32,  401,  665,   14,    2,  432,  848,  808,   49,   22,  432,  848,  808,   30,   35,   68,    2,  116,   39, 
   57,  896,    6,  237,    1,  112,    9,  508,  922,    2,   83,  479,    1,  106,   35,   13,  382,  203,   39,    9,  641,   32,   96,  168, 
   19,   59,  117,    1,   62,   13,  382,  203,   39,  351,   37,  309,  641,   32,  309,   51,    1,   35,   13,    9,  102,  406,  621,    4, 
   23,  136,  121,  298,  454,  103,  110,  177,    1,  145,   57,    2,  211,  215,  443,  541,    1, 1171,  736,    2,    9,   14,   37,   83, 
  170,    1,   22,   35,   68,    2,   14,   37,    9,  173,   45, 1652,  136,    6,   57, 1652,  516,  565,    1,   35,   68,  151, 1171,  736, 
    2,    9,   14,   37,    1,   62,  513,  755,   57, 1652,    1,   91,  253,   71,   15,   45,  655,   15,   57,  896,    1,   35,   68,   13, 
  318,  165,  894,    4,   23,    3,  136,  121,  298,  454,  103,   34,  145,   57,  211,  215,  443,  541,    2,  878,  503,  516,  565,  304, 
   31,  648,  208,   49,   22,   83,  117,  147,   64,  219,  246,   12,  152,   66,    1, 1290,  455,  164,  154,  234,   36,   12,    1, 1000, 
  316,  164,   15,  998,  812, 1289,  112,   36,   12,    1, 1426,  201,  119, 1078,  319,  512,   71,    8,  182,  124,  238,  230,  123,  901, 
    1,  184,  222,    6,   87,  435,   71,   60,   20,  211,  215,  443,  541,    2,    6,  170,    1,   16,   94,  294,  475,  419, 2450,    9, 
  571,   11,    1,   63,    8,    7,    5,    5,  122, 1080,   35,   68,   12,    4,  846,  337,   61,  301,  701,  297,   39,    6,  539,   27, 
  135,  979,    1,   35,  166,  181,   90,  143])}

4.2 columns_list使用




def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    """Load only the ``class`` column of the train TFRecord files and print every row.

    A Schema with the int64 columns ``input`` (shape [800]) and ``class``
    (shape [1]) is supplied, and ``columns_list`` restricts reading to
    ``class``. ``tfrecord_json`` is not used by this variant.
    """
    files = get_tfrecord_files(tfrecord_dir)

    schema = Schema()
    for column, shape in (("input", [800]), ("class", [1])):
        schema.add_column(name=column, de_type=mstype.int64, shape=shape)

    dataset = TFRecordDataset(
        dataset_files=files,
        schema=schema,
        columns_list=["class"],
        shuffle=False,
    )

    for row in dataset.create_dict_iterator():
        print(row, flush=True)



{'class': Tensor(shape=[1], dtype=Int64, value= [0])}

5. 本文总结


6. 本文参考


