Speech Recognition (2): Data Processing

Goals for this section

2.1 Match each wav file with its corresponding label file
2.2 Convert each wav into a spectrogram via feature extraction, so that all feature arrays within a batch share the same dimensions
2.3 Build a vocab from every character appearing in the labels, and map each label to vocab indices (token_2_index)
2.4 Build a Dataset subclass and a DataLoader that groups samples by batch_size

2.1 Building the wav and label file lists

import os

def source_get(source_file):
    # Walk the THCHS-30 data directory and collect matching wav/label paths
    train_file = source_file + '/data'
    label_lst = []
    wav_lst = []
    for root, dirs, files in os.walk(train_file):
        for file in files:
            if file.endswith('.wav') or file.endswith('.WAV'):
                wav_file = os.sep.join([root, file])
                label_file = wav_file + '.trn'   # each wav has a matching .trn transcript
                wav_lst.append(wav_file)
                label_lst.append(label_file)
    return label_lst, wav_lst
source_file = '/home/hyx/文档/data/data_thchs30'
label_lst, wav_lst = source_get(source_file)
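
The .trn files collected above hold the transcripts. Assuming the standard THCHS-30 layout, where the first line of each .trn file under data/ is the character transcript, a minimal sketch of reading them into a list (the helper read_label is mine, not from the original post):

def read_label(label_lst):
    label_data = []
    for label_file in label_lst:
        with open(label_file, 'r', encoding='utf-8') as f:
            label_data.append(f.readline().strip())   # first line: character transcript
    return label_data

label_data = read_label(label_lst)
print(wav_lst[0])
print(label_data[0])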

2.2 train_data & pad

import numpy as np

# Feature Padding Function
# Parameters
#     - x          : list of np.array, each with shape (time, feature_dim)
#     - pad_len    : int, length to pad to (0 means use the max length in x)
# Return
#     - new_x      : np.array with shape (len(x), pad_len, feature_dim)
def zero_padding(x, pad_len):
    features = x[0].shape[-1]
    if pad_len == 0: pad_len = max([len(v) for v in x])   # '== 0', not 'is 0'
    new_x = np.zeros((len(x), pad_len, features))
    for idx, ins in enumerate(x):
        new_x[idx, :min(len(ins), pad_len), :] = ins[:min(len(ins), pad_len), :]
    return new_x
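
The goals above mention turning each wav into a spectrogram before padding, but the post only shows the padding helper. A minimal sketch of that front end, assuming librosa log-Mel features (the function name wav_to_fbank and all parameter values are illustrative, not from the original code):

import librosa

def wav_to_fbank(wav_path, sr=16000, n_mels=80):
    # Load a wav file and return a log-Mel spectrogram of shape (frames, n_mels)
    wav, _ = librosa.load(wav_path, sr=sr)
    mel = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=400,
                                         hop_length=160, n_mels=n_mels)
    return librosa.power_to_db(mel).T   # transpose so time is the first axis

# Pad a small batch to a common length so every array in the batch has the same shape
feats = [wav_to_fbank(p) for p in wav_lst[:4]]
batch = zero_padding(feats, pad_len=0)   # -> shape (4, max_len, n_mels)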

2.3 label_data & pad

# Target Padding Function
# Parameters
#     - y          : list of int token sequences
#     - max_len    : int, max output length (0 means use the max length in y)
# Return
#     - new_y      : np.array with shape (len(y), max_len)
def target_padding(y, max_len):
    if max_len == 0: max_len = max([len(v) for v in y])   # '== 0', not 'is 0'
    new_y = np.zeros((len(y), max_len), dtype=int)
    for idx, label_seq in enumerate(y):
        new_y[idx, :len(label_seq)] = np.array(label_seq)
    return new_y
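
A quick usage example with made-up token ids: shorter sequences are right-padded with 0.

labels = [[0, 15, 27, 1], [0, 8, 1]]          # two encoded sequences of different length
padded = target_padding(labels, max_len=0)     # -> shape (2, 4); the short row is padded with 0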

2.4 token2idx & vocab

from operator import itemgetter

def encode_target(input_list, table=None, mode='subword', max_idx=500):
    if table is None:
        ### Step 1. Calculate token frequency
        table = {}
        for target in input_list:
            for t in target:
                if t not in table:
                    table[t] = 1
                else:
                    table[t] += 1
        ### Step 2. Keep the top-k tokens for the encode map
        # reserve 3 slots for <sos>/<eos>/<unk>
        max_idx = min(max_idx - 3, len(table))
        all_tokens = [k for k, v in sorted(table.items(), key=itemgetter(1), reverse=True)][:max_idx]
        table = {'<sos>': 0, '<eos>': 1}
        if mode == "word": table['<unk>'] = 2
        for tok in all_tokens:
            table[tok] = len(table)
    ### Step 3. Encode, wrapping every sequence with <sos> ... <eos>
    output_list = []
    for target in input_list:
        tmp = [0]                       # <sos>
        for t in target:
            if t in table:
                tmp.append(table[t])
            else:
                if mode == "word":
                    tmp.append(table['<unk>'])   # map OOV words to <unk>
                else:
                    raise ValueError('OOV error: ' + t)   # char/subword vocab has no <unk> entry
        tmp.append(1)                   # <eos>
        output_list.append(tmp)
    return output_list, table
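
Putting 2.1 and 2.4 together: build the vocab from the transcripts collected earlier and encode them. Here label_data is the list from the sketch in 2.1, and mode='word' is an assumption so that out-of-vocabulary characters map to <unk> instead of raising:

train_tokens, vocab = encode_target(label_data, table=None, mode='word', max_idx=500)
# vocab looks like {'<sos>': 0, '<eos>': 1, '<unk>': 2, ...}
# every encoded sequence is wrapped with <sos>/<eos>, so it can go straight into target_padding
padded_tokens = target_padding(train_tokens, max_len=0)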

2.5 Dataset

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class LibriDataset(Dataset):
    def __init__(self, file_path, sets, bucket_size, max_timestep=0, max_label_len=0, drop=False, text_only=False):
        # Read file
        self.root = file_path
        tables = [pd.read_csv(os.path.join(file_path,s+'.csv')) for s in sets]
        self.table = pd.concat(tables,ignore_index=True).sort_values(by=['length'],ascending=False)
        self.text_only = text_only

        # Crop seqs that are too long
        if drop and max_timestep >0 and not text_only:
            self.table = self.table[self.table.length < max_timestep]
        if drop and max_label_len >0:
            self.table = self.table[self.table.label.str.count('_')+1 < max_label_len]

        X = self.table['file_path'].tolist()
        X_lens = self.table['length'].tolist()
            
        Y = [list(map(int, label.split('_'))) for label in self.table['label'].tolist()]
        if text_only:
            Y.sort(key=len,reverse=True)

        # Bucketing, X & X_len is dummy when text_only==True
        self.X = []
        self.Y = []
        tmp_x,tmp_len,tmp_y = [],[],[]

        for x,x_len,y in zip(X,X_lens,Y):
            tmp_x.append(x)
            tmp_len.append(x_len)
            tmp_y.append(y)
            # Halve the batch size if the sequences in the bucket are too long
            if len(tmp_x)== bucket_size:
                if (bucket_size>=2) and ((max(tmp_len)> HALF_BATCHSIZE_TIME) or (max([len(y) for y in tmp_y])>HALF_BATCHSIZE_LABEL)):
                    self.X.append(tmp_x[:bucket_size//2])
                    self.X.append(tmp_x[bucket_size//2:])
                    self.Y.append(tmp_y[:bucket_size//2])
                    self.Y.append(tmp_y[bucket_size//2:])
                else:
                    self.X.append(tmp_x)
                    self.Y.append(tmp_y)
                tmp_x,tmp_len,tmp_y = [],[],[]
        if len(tmp_x)>0:
            self.X.append(tmp_x)
            self.Y.append(tmp_y)


    def __getitem__(self, index):
        # Load label
        y = [y for y in self.Y[index]]
        y = target_padding(y, max([len(v) for v in y]))
        if self.text_only:
            return y
        
        # Load acoustic feature and pad
        x = [torch.FloatTensor(np.load(os.path.join(self.root,f))) for f in self.X[index]]
        x = pad_sequence(x, batch_first=True)
        return x,y
            
    
    def __len__(self):
        return len(self.Y)
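
LibriDataset relies on two module-level thresholds that the post never defines; the values below are placeholders rather than the original settings, and the instantiation is only an illustration of the expected CSV-based layout:

HALF_BATCHSIZE_TIME = 800     # if the longest utterance in a bucket exceeds this many frames, split the bucket
HALF_BATCHSIZE_LABEL = 150    # same rule for very long label sequences

ds = LibriDataset(file_path='data/libri_fbank80', sets=['train-clean-100'],
                  bucket_size=32, max_timestep=3000, max_label_len=400, drop=True)
x, y = ds[0]    # x: (bucket, T, feature_dim) padded features, y: (bucket, L) padded token ids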

2.6 Dataloader

def LoadDataset(split, text_only, data_path, batch_size, max_timestep, max_label_len, use_gpu, n_jobs,
                dataset, train_set, dev_set, test_set, dev_batch_size, decode_beam_size,**kwargs):
    if split=='train':
        bs = batch_size
        shuffle = True
        sets = train_set
        drop_too_long = True
    elif split=='dev':
        bs = dev_batch_size
        shuffle = False
        sets = dev_set
        drop_too_long = True
    elif split=='test':
        bs = 1 if decode_beam_size>1 else dev_batch_size
        n_jobs = 1
        shuffle = False
        sets = test_set
        drop_too_long = False
    elif split=='text':
        bs = batch_size
        shuffle = True
        sets = train_set
        drop_too_long = True
    else:
        raise NotImplementedError
        
    if dataset.upper() == "TIMIT":
        assert not text_only,'TIMIT does not support text only.'
        ds = TimitDataset(file_path=data_path, sets=sets, max_timestep=max_timestep, 
                           max_label_len=max_label_len, bucket_size=bs)
    elif dataset.upper() =="LIBRISPEECH":
        ds = LibriDataset(file_path=data_path, sets=sets, max_timestep=max_timestep,text_only=text_only,
                           max_label_len=max_label_len, bucket_size=bs,drop=drop_too_long)
    else:
        raise ValueError('Unsupported Dataset: '+dataset)

    # Each dataset item is already a full bucket, so the DataLoader itself uses batch_size=1
    return DataLoader(ds, batch_size=1, shuffle=shuffle, drop_last=False, num_workers=n_jobs, pin_memory=use_gpu)
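
Because every item of the dataset is already a full bucket, the DataLoader adds a dummy leading dimension that the training loop has to squeeze away. A usage sketch with illustrative arguments (none of the paths or sizes come from the post):

train_loader = LoadDataset('train', text_only=False, data_path='data/libri_fbank80',
                           batch_size=32, max_timestep=3000, max_label_len=400,
                           use_gpu=True, n_jobs=4, dataset='librispeech',
                           train_set=['train-clean-100'], dev_set=['dev-clean'],
                           test_set=['test-clean'], dev_batch_size=16, decode_beam_size=1)

for x, y in train_loader:
    x, y = x.squeeze(0), y.squeeze(0)   # drop the extra dim added by batch_size=1
    print(x.shape, y.shape)             # (bucket, T, feature_dim), (bucket, L)
    break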