运行一段博客上的代码:
# encoding: utf-8
import tensorflow as tf # 0.12
import numpy as np
import os
from collections import Counter
import librosa # https://github.com/librosa/librosa
# Root directory that is searched recursively for training audio.
wav_path = '/home/liuyuan/Desktop/data/wav/train'
# Transcript file; each line holds an utterance id, a space, then the text.
label_file = '/home/liuyuan/Desktop/data/doc/trans/train.word.txt'


def get_wav_files(wav_path=wav_path, min_size=240000):
    """Collect the paths of all usable WAV files below ``wav_path``.

    Args:
        wav_path: Directory that is walked recursively.
        min_size: Minimum file size in bytes. Files smaller than this are
            skipped. Defaults to 240000, the threshold this script has
            always used.

    Returns:
        A list of file paths in the order ``os.walk`` yields them. The list
        is empty when ``wav_path`` does not exist or holds no matching file.
    """
    wav_files = []
    for dirpath, _dirnames, filenames in os.walk(wav_path):
        for filename in filenames:
            # Case-insensitive so ".wav", ".WAV" and mixed-case spellings
            # are all accepted.
            if not filename.lower().endswith('.wav'):
                continue
            # os.path.join avoids a doubled separator when dirpath already
            # ends with one.
            filename_path = os.path.join(dirpath, filename)
            if os.stat(filename_path).st_size < min_size:
                continue
            wav_files.append(filename_path)
    return wav_files


wav_files = get_wav_files()
def get_wav_lable(wav_files=wav_files, label_file=label_file,
                  encoding='utf-8'):
    """Pair every audio file with its transcript.

    Each line of ``label_file`` is split at the first space into an
    utterance id and its text. An audio file is matched by the part of its
    base name before the first ``.``.

    Args:
        wav_files: Iterable of audio file paths.
        label_file: Path of the transcript file.
        encoding: Text encoding of ``label_file``. Defaults to UTF-8
            (assumed to match the transcript file, as the script itself
            declares utf-8); pass another codec if the file differs.

    Returns:
        A tuple ``(new_wav_files, labels)`` of two parallel lists: the audio
        paths that have a transcript, and the transcript text for each.
        Audio files without a transcript are dropped.
    """
    labels_dict = {}
    # Decode explicitly: the transcripts are non-ASCII text, so relying on
    # the locale's default codec is not portable.
    with open(label_file, 'r', encoding=encoding) as f:
        for line in f:
            # Drop the line terminator only (also handles CRLF files).
            line = line.rstrip('\r\n')
            label_id, sep, label_text = line.partition(' ')
            if not sep:
                # Blank line or id without text: nothing to map, and the
                # previous split-based indexing would raise IndexError.
                continue
            labels_dict[label_id] = label_text

    labels = []
    new_wav_files = []
    for wav_file in wav_files:
        wav_id = os.path.basename(wav_file).split('.')[0]
        if wav_id in labels_dict:
            labels.append(labels_dict[wav_id])
            new_wav_files.append(wav_file)
    return new_wav_files, labels
# Keep only the recordings that have a transcript; both lists are parallel.
wav_files, labels = get_wav_lable()
print("样本数:", len(wav_files))  # 8911 in the blog author's run
if not wav_files:
    # Fail with an actionable message instead of the cryptic
    # "not enough values to unpack" that the vocabulary code would raise.
    raise ValueError(
        "no labelled .wav files found; check wav_path (%r) and label_file "
        "(%r) - note that audio files below the size threshold are skipped"
        % (wav_path, label_file)
    )
# Example pair (wav_files[0] -> labels[0]):
# wav/train/A11/A11_0.WAV -> 绿 是 阳春 烟 景 大块 文章 的 底色 四月 的 林 峦 更是 绿 得 鲜活 秀媚 诗意 盎然

# Vocabulary (see exercises 1 and 7)
# Every transcript is broken into single characters, spaces included.
all_words = [word for label in labels for word in label]
counter = Counter(all_words)
# Most frequent character first, so frequent characters get the small ids.
count_pairs = sorted(counter.items(), key=lambda pair: -pair[1])
# Built without tuple-unpacking so an empty vocabulary cannot raise here.
words = tuple(word for word, _count in count_pairs)
words_size = len(words)
print('词汇表大小:', words_size)
word_num_map = dict(zip(words, range(len(words))))


def to_num(word):
    """Return the id of ``word``; unknown characters map to ``len(words)``."""
    return word_num_map.get(word, len(words))


# One list of character ids per transcript, parallel to ``labels``.
labels_vector = [[to_num(word) for word in label] for label in labels]
# Example from the blog author's run:
# wav/train/A11/A11_0.WAV -> [479, 0, 7, 0, 138, 268, 0, 222, 0, 714, 0, 23, 261, 0, 28, 1191, 0, 1, 0, 442, 199, 0, 72, 38, 0, 1, 0, 463, 0, 1184, 0, 269, 7, 0, 479, 0, 70, 0, 816, 254, 0, 675, 1707, 0, 1255, 136, 0, 2020, 91]
# where words[479] is 绿
label_max_len = np.max([len(label) for label in labels_vector])
print('最长句子的字数:', label_max_len)
# Length of the longest utterance, counted in MFCC frames
# (673 in the blog author's run).
wav_max_len = 0
for wav_file in wav_files:
    signal, sr = librosa.load(wav_file, mono=True)
    # Pass y/sr by keyword: newer librosa releases reject positional
    # arguments here, while keywords work on old and new releases alike.
    # Transposed so that the first axis is the one being measured.
    mfcc = librosa.feature.mfcc(y=signal, sr=sr).T
    wav_max_len = max(wav_max_len, len(mfcc))
print("最长的语音:", wav_max_len)

batch_size = 16
# Number of full batches; a trailing partial batch is not counted.
n_batch = len(wav_files) // batch_size

# Fetch one batch
# Index of the next sample to hand out; advanced by get_next_batches.
pointer = 0
def get_next_batches(batch_size):
global pointer
batches_wavs = []
batches_labels = []
for i in range(batch_size):
wav, sr = librosa.load(wav_files[pointer], mono=True)
mfcc = np.transpose(librosa.feature.mfcc(wav, sr), [1, 0])
batches_wavs.append(mfcc.tolist())
batches_labels.append(labels_vector[pointer])
pointer +=