分析流程从main函数开始,仅分析数据的变化
data_loader
学习了pytorch的data.Dataset
def __init__(self, root_dir, len_crop):
    """Initialize and preprocess the Utterances dataset.

    Args:
        root_dir: directory containing train.pkl and the spectrogram .npy files.
        len_crop: sequence length that __getitem__ pads/crops each utterance to.
    """
    self.root_dir = root_dir
    self.len_crop = len_crop
    self.step = 10  # number of metadata entries handled by each worker process

    metaname = os.path.join(self.root_dir, "train.pkl")
    # Use a context manager so the pickle file handle is closed deterministically
    # (the original `pickle.load(open(...))` left it for the GC to close).
    with open(metaname, "rb") as f:
        meta = pickle.load(f)

    # Load data using multiprocessing: split the metadata into chunks of
    # self.step entries and give each chunk to one Process.  `dataset` is a
    # Manager-shared list pre-filled with None; each worker fills slot i..i+step
    # with the utterance data it loads via self.load_data.
    manager = Manager()
    meta = manager.list(meta)
    dataset = manager.list(len(meta) * [None])
    processes = []
    for i in range(0, len(meta), self.step):
        p = Process(target=self.load_data,
                    args=(meta[i:i + self.step], dataset, i))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    self.train_dataset = list(dataset)  # materialized contents of train.pkl
    self.num_tokens = len(self.train_dataset)

    print('Finished loading the dataset...')
需要重写的方法:__init__()和__getitem__()
__init__():定义了类中的全局变量
root_dir就是main.py中的data_dir==/spmel
meta是加载出来的train.pkl临时版,load_data()是根据meta中的路径加载对应npy文件,并将npy文件中的数据和meta中的说话人、说话人嵌入放在一起,最后会被加载到self.train_dataset里面,__init__()做了这些处理。
def __getitem__(self, index):
    """Return (uttr, emb_org) for the speaker at `index`.

    uttr is one randomly chosen utterance, padded or cropped to
    self.len_crop frames; emb_org is the speaker embedding.
    """
    entry = self.train_dataset[index]  # [speaker_id, embedding, mel_1, mel_2, ...]
    emb_org = entry[1]

    # Utterance arrays start at slot 2; pick one uniformly at random.
    pick = np.random.randint(2, len(entry))
    mel = entry[pick]

    n_frames = mel.shape[0]
    if n_frames < self.len_crop:
        # Too short: zero-pad at the end up to len_crop frames.
        uttr = np.pad(mel, ((0, self.len_crop - n_frames), (0, 0)), 'constant')
    elif n_frames > self.len_crop:
        # Too long: take a random contiguous crop of len_crop frames.
        start = np.random.randint(n_frames - self.len_crop)
        uttr = mel[start:start + self.len_crop, :]
    else:
        uttr = mel

    return uttr, emb_org
最后data_loader,是data.DataLoader的一个对象
# Wrap the dataset in a PyTorch DataLoader.
# drop_last=True keeps every batch at exactly batch_size samples;
# worker_init_fn presumably seeds each worker process — it is defined
# elsewhere in the file (TODO confirm).
data_loader = data.DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              drop_last=True,
                              worker_init_fn=worker_init_fn)
Solver_encoder
def __init__(self, vcc_loader, config):
    """Initialize configurations.

    Args:
        vcc_loader: DataLoader over the Utterances dataset (yields mel
            crops and speaker embeddings).
        config: argparse Namespace carrying the hyper-parameters copied below.
    """
    # Data loader.
    self.vcc_loader = vcc_loader

    # Model configurations (defaults come from the argparse setup in main).
    self.lambda_cd = config.lambda_cd
    self.dim_neck = config.dim_neck
    self.dim_emb = config.dim_emb
    self.dim_pre = config.dim_pre
    self.freq = config.freq

    # Training configurations.
    self.batch_size = config.batch_size
    self.num_iters = config.num_iters

    # Miscellaneous.
    self.use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
    self.log_step = config.log_step
    self.sample_step = config.sample_step  # NOTE(review): usage not visible in this chunk — confirm
    self.model_save_step = config.model_save_step
    self.model_save_dir = config.model_save_dir
    self.resume_iters = config.resume_iters

    # Build the model and tensorboard.
    self.build_model()
__init__()加载参数,以及数据,见main函数中的参数配置
parser.add_argument('--dim_neck', type=int, default=16)
parser.add_argument('--dim_emb', type=int, default=256)
parser.add_argument('--dim_pre', type=int, default=512)
parser.add_argument('--freq', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size')
parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')
train的代码很长
加载数据,这里batchsize=2,x_real中每一个元素长度为128,emb_org的长度为256
x_real是mel谱,emb_org是说话人嵌入。
# Fetch the next (mel, speaker-embedding) batch; rebuild the iterator when
# the epoch is exhausted.  The original used a bare `except:`, which also
# hides real bugs (KeyboardInterrupt, typos, CUDA errors).  We narrow it to
# the two expected cases: StopIteration at end of epoch, and NameError on
# the very first iteration when data_iter has not been assigned yet.
try:
    x_real, emb_org = next(data_iter)
except (StopIteration, NameError):
    data_iter = iter(data_loader)
    x_real, emb_org = next(data_iter)
x_real = x_real.to(self.device)
emb_org = emb_org.to(self.device)
Solver_encoder中加载了一个生成器模型,此模型来自model_vc,且传过去的是main中写好的定值参数,而不是数据
from model_vc import Generator
def build_model(self):
    """Create the generator, its Adam optimizer, move it to the device, and set train mode."""
    net = Generator(self.dim_neck, self.dim_emb, self.dim_pre, self.freq)
    optimizer = torch.optim.Adam(net.parameters(), 0.0001)
    net.to(self.device)
    self.G = net.train()          # train() returns the module itself
    self.g_optimizer = optimizer
Generator中有三个模块:Encoder,Decoder和Postnet
class Generator(nn.Module):
    """AutoVC generator: content encoder + decoder + post-net."""

    def __init__(self, dim_neck, dim_emb, dim_pre, freq):
        super(Generator, self).__init__()
        self.encoder = Encoder(dim_neck, dim_emb, freq)
        self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        self.postnet = Postnet()

    def forward(self, x, c_org, c_trg):
        """Convert mel x from speaker c_org toward speaker c_trg.

        When c_trg is None, only the concatenated content codes are
        returned (encoder-only mode).
        """
        codes = self.encoder(x, c_org)
        if c_trg is None:
            return torch.cat(codes, dim=-1)

        # Up-sample each downsampled code back to the input time length.
        seg_len = int(x.size(1) / len(codes))
        expanded = [c.unsqueeze(1).expand(-1, seg_len, -1) for c in codes]
        code_exp = torch.cat(expanded, dim=1)

        # Append the target speaker embedding to every frame.
        trg = c_trg.unsqueeze(1).expand(-1, x.size(1), -1)
        encoder_outputs = torch.cat((code_exp, trg), dim=-1)

        mel_outputs = self.decoder(encoder_outputs)
        residual = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + residual.transpose(2, 1)

        return (mel_outputs.unsqueeze(1),
                mel_outputs_postnet.unsqueeze(1),
                torch.cat(codes, dim=-1))