首先是数据集处理部分,对['speaker_id', 'file_name', 'path', 'sys_id', 'key']
通过torch.save
保存为npy文件。
cm_protocol可以是ASV2019路径下ASVspoof2019_LA_cm_protocols
文件夹的ASVspoof2019.LA.cm.train.trn.txt
文件。
最核心的部分在class ASVDataset(Dataset)的def __init__
。
#cm_protocol可以是ASVspoof2019_LA_cm_protocols文件夹的ASVspoof2019.LA.cm.train.trn.txt文件
#返回ASVFile(speaker_id,file_name,path,sys_id,key)
files_meta = parse_protocols_file(cm_protocol)
#返回音频采样点data_x, 标签data_y, 第几个篡改方法sys_id
data = list(map(read_file, files_meta))
data_x, data_y, data_sysid = map(list, zip(*data))
#保存
torch.save((self.data_x, self.data_y, self.data_sysid, self.files_meta), self.cache_fname)
此外,加入了torch.vision.transform()以后,torch.save前加入一行:
if self.transform:
self.data_x = Parallel(n_jobs=4, prefer='threads')(delayed(self.transform)(x) for x in self.data_x)
其中,transform使用:
transforms = transforms.Compose([
lambda x: pad(x),
lambda x: Tensor(x)
])
pad过程,填充至64600样本点。
max_len为64600,超过max_len就截段x[:max_len]
,低于max_len就num_repeats循环复制直到填充至64600样本点np.tile(x, (1, num_repeats))[:, :max_len][0]
def pad(x, max_len=64600):
x_len = x.shape[0]
if x_len >= max_len:
return x[:max_len]
# need to pad
num_repeats = int(max_len / x_len)+1 #向上取整
padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
return padded_x
解析parse_protocols_file:
def parse_protocols_file(self, protocols_fname):
#读取protocols的每一行,输入到_parse_line中处理,返回files_meta
lines = open(protocols_fname).readlines()
files_meta = map(self._parse_line, lines)
return list(files_meta)
#ASVFile为tuple结构
ASVFile = collections.namedtuple('ASVFile',
['speaker_id', 'file_name', 'path', 'sys_id', 'key'])
# 读取protocols,返回ASVFile的tuple=['speaker_id', 'file_name', 'path', 'sys_id', 'key']
def _parse_line(self, line):
tokens = line.strip().split(' ')
if self.is_eval:
return ASVFile(speaker_id=tokens[0],
file_name=tokens[1],
path=os.path.join(self.files_dir, tokens[1] + '.flac'),
sys_id=self.sysid_dict[tokens[3]],
key=int(tokens[4] == 'bonafide'))
return ASVFile(speaker_id=tokens[0],
file_name=tokens[1],
path=os.path.join(self.files_dir, tokens[1] + '.flac'),
sys_id=self.sysid_dict[tokens[3]],
key=int(tokens[4] == 'bonafide'))
sys_id也根据train/dev和eval的不同而不同,因为train和dev只有A01到A06,eval部分在训练集不可见,有A07到A19。部分代码如下:
if is_eval :
#eval部分
self.sysid_dict = {
'-': 0, # bonafide speech
'A07': 1,
'A08': 2,
'A09': 3,
'A10': 4,
'A11': 5,
'A12': 6,
'A13': 7,
'A14': 8,
'A15': 9,
'A16': 10,
'A17': 11,
'A18': 12,
'A19': 13,
}
else:
#train/dev部分
self.sysid_dict = {
'-': 0, # bonafide speech
'A01': 1,
'A02': 2,
'A03': 3,
'A04': 4,
'A05': 5,
'A06': 6,
}
解析后,通过read_file返回data_x,data_y,sys_id。
#meta中有['speaker_id', 'file_name', 'path', 'sys_id', 'key']
def read_file(self, meta):
data_x, sample_rate = sf.read(meta.path)#读取flac路径,将音频采样点保存到data_x中
data_y = meta.key# 是否是真实样本,0假,1真。
return data_x, float(data_y), meta.sys_id