一、Datasets大体格式
class SDataset(Dataset):
"""Dataset to load NSynth data."""
def __init__(self, audio_dir):
super().__init__()
'''
一般用于加载数据,比如对某个路径下的语音遍历,设置语音系统初始值。
'''
def __len__(self):
'''
返回长度
'''
def __getitem__(self, index):
'''
一般用于对加载后的语音信号进行处理,比如标准化,去噪,截取语音小段等
'''
二、__init__环节举例
通常必备的语音初始化环节
def __init__(self, audio_dir):
super().__init__()
self.segment_length = 8000 # 采样率
self.filesnames =[]
self.filesnames .extend(glob.glob(audio_dir+"/*.wav")) #加载语音
三、标准化
为了方便训练,通常会有一个标准化的过程。(可选)
def normalize_data(self, x):
x = x / max(abs(x)+1e-20)
return x
四、__getitem__环节举例
必备的语音处理环节,比如对语音的某一段进行截取,而不是全部输入语音。比如消除一些静音段
def __getitem__(self, index):
seq_max = 秒 x self.segment_length # 期望处理的语音片段
audio = torchaudio.load(self.filenames[index])[0] # 加载数据
ans = torch.zeros(1, seq_max) # 最大长度语音
'''
按照长度和标准差筛选语音
'''
while len(audio ) < seq_max or np.isclose(np.std(audio), 0):
index= (index+ 1) % self.__len__()
file_path = self.filesnames [index]
sr, audio = wavfile.read(file_path)
audio = self.normalize_data(audio)
audio = torchaudio.load(self.filenames[index])[0]
'''
如果长度大于最大语音片段,进行处理
'''
while 1:
if len(in_data) == seq_max :
loc = 0
else:
loc = torch.randint(len(audio)-seq_max , (1,))
selected_seg = in_data[loc: loc + seq_max]
if not np.isclose(np.std(selected_seg), 0):
seg = selected_seg
break