- 看一下这个工程中的数据加载方式
数据加载
1. Dataset 类
文件路径:examples/timit/data/load_dataset_ctc.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""Load dataset for the CTC model (TIMIT corpus).
In addition, frame stacking and skipping are used.
You can use only the single GPU version.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from os.path import join, isfile
import pickle
import numpy as np
# 首先从这里引入了 DatasetBase 这个类
from utils.dataset.ctc import DatasetBase
# Dataset 继承了 DatasetBase 这个类;再查看 DatasetBase 的定义,可以看到它又继承了 Base 这个类
class Dataset(DatasetBase):
def __init__(self, data_type, label_type, batch_size,
max_epoch=None, splice=1,
num_stack=1, num_skip=1,
shuffle=False, sort_utt=False, sort_stop_epoch=None,
progressbar=False):
"""A class for loading dataset.
Args:
data_type (string): train or dev or test
label_type (string): phone39 or phone48 or phone61 or
character or character_capital_divide
batch_size (int): the size of mini-batch
max_epoch (int, optional): the max epoch. None means infinite loop.
splice (int, optional): frames to splice. Default is 1 frame.
num_stack (int, optional): the number of frames to stack
num_skip (int, optional): the number of frames to skip
shuffle (bool, optional): if True, shuffle utterances. This is
disabled when sort_utt is True.
sort_utt (bool, optional): if True, sort all utterances by the
number of frames and utteraces in each mini-batch are shuffled.
Otherwise, shuffle utteraces.
sort_stop_epoch (int, optional): After sort_stop_epoch, training
will revert back to a random order
progressbar (bool, optional): if True, visualize progressbar
"""
# 这里先调用了父类的构造函数,这里没有参数
super(Dataset, self).__init__()
self.is_test = True if data_type == 'test' else False
self.data_type = data_type
self.label_type = label_type
self.batch_size = batch_size
self.max_epoch = max_epoch
self.splice = splice
self.num_stack = num_stack
self.num_skip = num_skip
self.shuffle = shuffle
self.sort_utt = sort_utt
self.sort_stop_epoch = sort_stop_epoch
self.progressbar = progressbar
self.num_gpu = 1
# paths where datasets exist
# 设置dataset路径
dataset_root = ['/data/inaguma/timit',
'/n/sd8/inaguma/corpus/timit/dataset']
input_path = join(dataset_root[0], 'inputs', data_type)
# NOTE: ex.) save_path: timit_dataset_path/inputs/data_type/***.npy
label_path = join(dataset_root[0], 'labels', data_type, label_type)
# NOTE: ex.) save_path:
# timit_dataset_path/labels/data_type/label_type/***.npy
# Load the frame number dictionary
# 加载帧数字典,看后面的代码可以知道字典键名是文件名,键值是帧数
# 加载的是 .pickle文件, pickle存储的是结构化文件,如用pickle.dump()存储一个字典,再pickle.load()出来就是字典,而不是文本
# 这里的if-else 就是为了判断刚才两个目录哪个下面有这个文件
if isfile(join(input_path, 'frame_num.pickle')):
with open(join(input_path, 'frame_num.pickle'), 'rb') as f:
self.frame_num_dict = pickle.load(f)
else:
dataset_root.pop(0)
input_path = join(dataset_root[0], 'inputs', data_type)