首先,在导入部分添加以下内容:
import torch.multiprocessing as mp
在 __main__
部分的开始处,使用 mp.Value
创建一个列表,然后将其传递给 Dataset
类:
if __name__ == "__main__":
mel_spectrograms = mp.Manager().list()
train_dataset = Dataset('train', mel_spectrograms)
test_dataset = Dataset('val', mel_spectrograms)
# ...
然后,修改 Dataset
类的 __init__
方法,以便它接受一个额外的参数 mel_spectrograms
:
class Dataset(object):
def __init__(self, split, mel_spectrograms):
self.all_videos = get_image_list(args.data_root, split)
self.mel_spectrograms = mel_spectrograms
最后,将 Dataset
类中所有引用 mel_spectrograms
全局变量的地方替换为 self.mel_spectrograms
。
这样,所有子进程都将使用同一个 mel_spectrograms
列表,从而避免了覆盖已计算值的问题。此外,使用 torch.multiprocessing
模块应该比使用 Manager
类具有更好的性能。
请注意,由于我们现在使用的是列表而不是字典,您需要相应地更新 __getitem__
方法中与 mel_spectrograms
相关的代码。您可以将 vidname
作为键添加到 mel_spectrograms
列表中,以便在需要时检索它。
try:
wavpath = join(vidname, "audio.wav")
found = False
for item in self.mel_spectrograms:
if item[0] == vidname:
orig_mel = item[1]
found = True
break
if not found:
wav = audio.load_wav(wavpath, hparams.sample_rate)
orig_mel = audio.melspectrogram(wav).T
self.mel_spectrograms.append((vidname, orig_mel))
except Exception as e:
continue