自用AutoVC代码分析笔记

 分析流程从main函数开始,仅分析数据的变化

data_loader

学习了pytorch的data.Dataset

def __init__(self, root_dir, len_crop):
        """Initialize and preprocess the Utterances dataset."""
        self.root_dir = root_dir
        self.len_crop = len_crop
        self.step = 10
        
        metaname = os.path.join(self.root_dir, "train.pkl")
        meta = pickle.load(open(metaname, "rb"))
        
        """Load data using multiprocessing"""
#使用多进程加载数据,将数据分成多个步骤加载,以提高加载效率。加载的数据存储在 dataset 列表中
        manager = Manager()
        meta = manager.list(meta)
        dataset = manager.list(len(meta)*[None])  
#创建了一个共享的列表,用于在多个进程中存储加载的语音数据。每个元素最初都设置为 None,在加载数据时,会填充相应的语音数据。
        processes = []
        for i in range(0, len(meta), self.step):
            p = Process(target=self.load_data, 
                        args=(meta[i:i+self.step],dataset,i))  
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
            
        self.train_dataset = list(dataset)#train.pkl中加载的
        self.num_tokens = len(self.train_dataset)
        
        print('Finished loading the dataset...')

需要重写的方法:__init__()和__getitem__()

__init__():定义了类中的全局变量

root_dir就是main.py中的data_dir==/spmel

meta是加载出来的train.pkl临时版,load_data()是根据mate中的路径加载对应npy文件,并将npy文件中的数据和meta中的说话人、说话人嵌入放在一起,最后会被加载到self.train_dataset里面,__init__()做了这些处理。

    def __getitem__(self, index):
        #index 表示当前要获取的样本在数据集中的索引
        # pick a random speaker
        dataset = self.train_dataset 
        list_uttrs = dataset[index] #某一说话人的所有数据
        emb_org = list_uttrs[1]  #说话人嵌入
        
        # pick random uttr with random crop
        a = np.random.randint(2, len(list_uttrs))
        tmp = list_uttrs[a] #随机挑一条语句
        if tmp.shape[0] < self.len_crop:#填充和截取至指定长度(128)
            len_pad = self.len_crop - tmp.shape[0]
            uttr = np.pad(tmp, ((0,len_pad),(0,0)), 'constant')
        elif tmp.shape[0] > self.len_crop:
            left = np.random.randint(tmp.shape[0]-self.len_crop)
            uttr = tmp[left:left+self.len_crop, :]
        else:
            uttr = tmp
        
        return uttr, emb_org

最后data_loader,是data.DataLoader的一个对象

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  worker_init_fn=worker_init_fn)

Solver_encoder

def __init__(self, vcc_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.vcc_loader = vcc_loader

        # Model configurations.
        self.lambda_cd = config.lambda_cd
        self.dim_neck = config.dim_neck
        self.dim_emb = config.dim_emb
        self.dim_pre = config.dim_pre
        self.freq = config.freq

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        
        # Miscellaneous.
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        self.log_step = config.log_step

        self.sample_step = config.sample_step   #MARK
        self.model_save_step = config.model_save_step
        self.model_save_dir = config.model_save_dir
        self.resume_iters = config.resume_iters

        # Build the model and tensorboard.
        self.build_model()

__init__()加载参数,以及数据,见main函数中的参数配置

    parser.add_argument('--dim_neck', type=int, default=16)
    parser.add_argument('--dim_emb', type=int, default=256)
    parser.add_argument('--dim_pre', type=int, default=512)
    parser.add_argument('--freq', type=int, default=16)
    parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size')
    parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')

train的代码很长

加载数据,这里batchsize=2,x_real中每一个元素长度为128,emb_org的长度为256

x_real只mel谱,emb_org是说话人嵌入。

 try:
                x_real, emb_org = next(data_iter)
            except:
                data_iter = iter(data_loader)
                x_real, emb_org = next(data_iter)
            
            x_real = x_real.to(self.device) 
            emb_org = emb_org.to(self.device) 

 Solver_encoder中加载了一个生成器模型,此模型来自model_vc,且传过去的是main中写好的定值参数,而不是数据

from model_vc import Generator


def build_model(self):
        
        self.G = Generator(self.dim_neck, self.dim_emb, self.dim_pre, self.freq)        
        
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), 0.0001)
        
        self.G.to(self.device)
        

self.G = self.G.train() 

Generator中有三个模块:Encoder,Decoder和Postnet

class Generator(nn.Module):
    """Generator network."""
    def __init__(self, dim_neck, dim_emb, dim_pre, freq):
        super(Generator, self).__init__()
        
        self.encoder = Encoder(dim_neck, dim_emb, freq)
        self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        self.postnet = Postnet()

    def forward(self, x, c_org, c_trg):
                
        codes = self.encoder(x, c_org)
        if c_trg is None:
            return torch.cat(codes, dim=-1)
        
        tmp = []
        for code in codes:
            tmp.append(code.unsqueeze(1).expand(-1,int(x.size(1)/len(codes)),-1))
        code_exp = torch.cat(tmp, dim=1)
        
        encoder_outputs = torch.cat((code_exp, c_trg.unsqueeze(1).expand(-1,x.size(1),-1)), dim=-1)
        
        mel_outputs = self.decoder(encoder_outputs)
                
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2,1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2,1)
        
        mel_outputs = mel_outputs.unsqueeze(1)
        mel_outputs_postnet = mel_outputs_postnet.unsqueeze(1)
        
        return mel_outputs, mel_outputs_postnet, torch.cat(codes, dim=-1)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值