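AudioReader is a utility class in the Dual-Path-RNN-Pytorch source code for reading scp files. Its input is the full path to an scp file.
An scp file is a plain-text list that maps file names to file paths, one entry per line. Sample content: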
1.wav /local/file/path/1.wav
2.wav /local/file/path/2.wav
3.wav /local/file/path/3.wav
4.wav /local/file/path/4.wav
5.wav /local/file/path/5.wav
6.wav /local/file/path/6.wav
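If you do not have an scp file yet, you can generate one from a directory of wav files. The sketch below assumes one whitespace-separated "key path" pair per line, with the file name as the key, as in the sample above. write_scp is a hypothetical helper and is not part of Dual-Path-RNN-Pytorch. Because the format splits on whitespace, paths that contain spaces would break it.

```python
from pathlib import Path


def write_scp(wav_dir, scp_path):
    """Hypothetical helper: write one '<file name> <absolute path>' line per .wav file."""
    # Sort so that the key order, and therefore integer indexing, is stable
    wav_files = sorted(Path(wav_dir).glob("*.wav"))
    with open(scp_path, "w", encoding="utf-8") as f:
        for wav in wav_files:
            f.write(f"{wav.name} {wav.resolve()}\n")
    return len(wav_files)  # number of entries written
```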
AudioReader loads each audio file listed in the scp file with torchaudio. Note that after installing torchaudio, some users may still run into the following error:
No audio I/O backend is available
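torchaudio relies on a backend, a separate library that does the actual work of decoding audio files, so you need to install one as well.
On Windows, torchaudio's default backend is soundfile. You can install it with pip install SoundFile.
On Linux, torchaudio's default backend is sox.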
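Before you call torchaudio.load, a quick diagnostic can show which backends torchaudio actually sees. The sketch assumes an installed version that exposes torchaudio.list_audio_backends(), which many releases provide but not necessarily all. When torchaudio is missing, fails to import, or lacks that function, it returns an empty list. available_backends is a hypothetical name.

```python
import importlib.util


def available_backends():
    """Return torchaudio's usable audio backends, or [] if none can be determined."""
    if importlib.util.find_spec("torchaudio") is None:
        return []  # torchaudio is not installed at all
    try:
        import torchaudio
    except Exception:
        return []  # installed, but importing it failed (e.g. mismatched torch build)
    # Assumption: list_audio_backends() exists in the installed torchaudio version
    lister = getattr(torchaudio, "list_audio_backends", None)
    return list(lister()) if callable(lister) else []


if __name__ == "__main__":
    backends = available_backends()
    if backends:
        print("Available backends:", backends)
    else:
        print("No backend found; on Windows, try: pip install SoundFile")
```

An empty list corresponds to the "No audio I/O backend is available" situation described above.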
Its read function is therefore a thin wrapper around torchaudio.load:
import torchaudio


def read_wav(fname, return_rate=False):
    '''
    Read a wav file using torchaudio
    input:
        fname: wav file path
        return_rate: whether to also return the sampling rate
    output:
        src: tensor of size C x L, squeezed, so a mono file yields size L
            L is the number of audio frames
            C is the number of channels
        sr: sample rate
    '''
    # channels_first=True returns the audio as C x L
    src, sr = torchaudio.load(fname, channels_first=True)
    if return_rate:
        return src.squeeze(), sr
    else:
        return src.squeeze()
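If you cannot install torchaudio or its backend, for example in a quick unit test, Python's standard wave module can approximate the same contract. This substitutes a different library and is not what the repository does. The sketch assumes 16-bit PCM mono files, and read_wav_stdlib is a hypothetical name. It scales samples to floats in [-1, 1) by dividing by 32768, similar to torchaudio.load's default normalization of 16-bit PCM. It returns a plain list rather than a tensor.

```python
import array
import sys
import wave


def read_wav_stdlib(fname, return_rate=False):
    """Read a 16-bit PCM mono wav file into a list of floats in [-1, 1)."""
    with wave.open(fname, "rb") as w:
        if w.getsampwidth() != 2 or w.getnchannels() != 1:
            raise ValueError("this sketch only handles 16-bit mono PCM")
        sr = w.getframerate()
        raw = w.readframes(w.getnframes())
    samples = array.array("h", raw)  # signed 16-bit integers
    if sys.byteorder == "big":
        samples.byteswap()  # wav sample data is little-endian
    src = [s / 32768.0 for s in samples]
    return (src, sr) if return_rate else src
```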
You can think of the AudioReader class as a dictionary (dict) whose internal form is:
{
    '1.wav': '/local/file/path/1.wav',
    '2.wav': '/local/file/path/2.wav',
    '3.wav': '/local/file/path/3.wav',
    '4.wav': '/local/file/path/4.wav',
    '5.wav': '/local/file/path/5.wav',
    '6.wav': '/local/file/path/6.wav'
}
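A lookup does not return the path, however. AudioReader loads the corresponding file and returns the audio data, which saves the caller a step.
It also extends the keys so that an int can be used as an index, referring to the entry at that position in the scp file. In short, AudioReader is not complicated, but it is well designed.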
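You can sketch this design independently of torchaudio: a key-to-path mapping that loads data only on access and also accepts integer positions. LazyAudioDict and its loader argument are hypothetical. In AudioReader, read_wav plays the loader role.

```python
class LazyAudioDict:
    """Map keys to file paths, but return loaded data on lookup."""

    def __init__(self, index_dict, loader):
        self.index_dict = dict(index_dict)  # key -> path
        self.keys = list(self.index_dict)   # insertion order defines int positions
        self.loader = loader                # called with a path, returns the data

    def __len__(self):
        return len(self.keys)

    def __iter__(self):
        # Yield (key, data) pairs, loading one file at a time
        for key in self.keys:
            yield key, self.loader(self.index_dict[key])

    def __getitem__(self, index):
        # bool is a subclass of int, so exclude it explicitly
        if isinstance(index, bool) or not isinstance(index, (int, str)):
            raise TypeError(f"unsupported index type: {type(index).__name__}")
        if isinstance(index, int):
            if not 0 <= index < len(self.keys):
                raise IndexError(f"index {index} out of range for {len(self.keys)} items")
            index = self.keys[index]  # translate position to key
        if index not in self.index_dict:
            raise KeyError(f"missing utterance {index!r}")
        return self.loader(self.index_dict[index])
```

Unlike AudioReader, this sketch raises IndexError for out-of-range positions and TypeError for unsupported index types, which follows Python's own container conventions more closely.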
The complete source code follows for reference.
import torchaudio


def read_wav(fname, return_rate=False):
    '''
    Read a wav file using torchaudio
    input:
        fname: wav file path
        return_rate: whether to also return the sampling rate
    output:
        src: tensor of size C x L, squeezed, so a mono file yields size L
            L is the number of audio frames
            C is the number of channels
        sr: sample rate
    '''
    # channels_first=True returns the audio as C x L
    src, sr = torchaudio.load(fname, channels_first=True)
    if return_rate:
        return src.squeeze(), sr
    else:
        return src.squeeze()


def write_wav(fname, src, sample_rate):
    '''
    Write a wav file
    input:
        fname: wav file path
        src: audio frames, a 2-D tensor of size C x L as torchaudio.save
            documents (unsqueeze(0) a squeezed mono tensor from read_wav first)
        sample_rate: integer sample rate of the audio
    output:
        None
    '''
    torchaudio.save(fname, src, sample_rate)
class AudioReader(object):
    '''
    Class that reads wav format files
    Input: path to an scp file
    Output: the audio tensors of all wav files listed in the scp file
    '''

    def __init__(self, scp_path, sample_rate=8000):
        super(AudioReader, self).__init__()
        self.sample_rate = sample_rate
        # handle_scp is not shown in this listing; it parses the scp file
        # into a {key: wav path} dict
        self.index_dict = handle_scp(scp_path)
        self.keys = list(self.index_dict.keys())

    def _load(self, key):
        # Load one file and check that its sample rate matches the expected one
        src, sr = read_wav(self.index_dict[key], return_rate=True)
        if self.sample_rate is not None and sr != self.sample_rate:
            raise RuntimeError('SampleRate mismatch: {:d} vs {:d}'.format(
                sr, self.sample_rate))
        return src

    def __len__(self):
        return len(self.keys)

    def __iter__(self):
        # Yield (key, audio tensor) pairs, loading each file on demand
        for key in self.keys:
            yield key, self._load(key)

    def __getitem__(self, index):
        if type(index) not in [int, str]:
            raise IndexError('Unsupported index type: {}'.format(type(index)))
        if type(index) == int:
            num_uttrs = len(self.keys)
            # Reject negative positions and positions past the last entry
            if index < 0 or index >= num_uttrs:
                raise KeyError('Integer index out of range, {:d} vs {:d}'.format(
                    index, num_uttrs))
            index = self.keys[index]
        if index not in self.index_dict:
            raise KeyError("Missing utterance {}!".format(index))
        return self._load(index)
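The listing calls handle_scp, which is not included above. Below is a minimal version consistent with how AudioReader uses it, returning a dict of key to path. This is a sketch under the assumption of one whitespace-separated "key path" pair per line, not the repository's exact code. It rejects duplicate keys and malformed lines so that integer indexing through self.keys stays unambiguous. Because Python dicts preserve insertion order, self.keys follows the line order of the scp file.

```python
def handle_scp(scp_path):
    """Parse an scp file into {key: wav path}, one 'key path' pair per line."""
    scp_dict = {}
    with open(scp_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            if len(parts) != 2:
                raise RuntimeError(
                    f"{scp_path}, line {line_no}: expected 'key path', got {parts}")
            key, path = parts
            if key in scp_dict:
                raise ValueError(f"duplicated key {key!r} in {scp_path}")
            scp_dict[key] = path
    return scp_dict
```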