【项目实战】AUTOVC 代码解析 —— solver_encoder.py

AUTOVC 代码解析 —— solver_encoder.py

  简介

       本项目一个基于 AUTOVC 模型的语音转换项目,它是使用 PyTorch 实现的(项目地址)。
       
        AUTOVC 遵循自动编码器框架,只对自动编码器损耗进行训练,但它引入了精心调整的降维和时间下采样来约束信息流,这个简单的方案带来了显著的性能提高。(详情请参阅 AUTOVC 的详细介绍)。
       
       由于 AUTOVC 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
       
       本文将介绍项目中的 solver_encoder.py 文件:设计了网络模型的解决方案。
       

  类解析

    Solver

        该类的意义为:网络模型的解决方案。
       
       下面依次介绍 Solver 类的成员函数
       

      __ init __

          该函数的作用是: 创建 Solver 网络模型解决方案需要的元素。

          输入参数:

		vcc_loader	:	数据迭代器
		config		:	网路模型配置

          输出参数:

          代码详解:

	    def __init__(self, vcc_loader, config):
	        """ 初始化配置 """
	
	        # 数据迭代器
	        self.vcc_loader = vcc_loader
	
	        # 模型配置
	        self.lambda_cd = config.lambda_cd
	        self.dim_neck = config.dim_neck
	        self.dim_emb = config.dim_emb
	        self.dim_pre = config.dim_pre
	        self.freq = config.freq
	
	        # 训练配置
	        self.batch_size = config.batch_size
	        # 训练次数
	        self.num_iters = config.num_iters
	        
	        # 杂项
	        # 检测 GPU
	        self.use_cuda = torch.cuda.is_available()
	        # 选择 GPU
	        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
	        # 日志存储间隔
	        self.log_step = config.log_step
	
	        # 建立模型与张量记录
	        self.build_model()
      build_model

          该函数的作用是: 创建语音转换网络模型,配置 Adam 优化器,并将数据移至选定的设备中。

          输入参数:

          输出参数:

          代码详解:

	    def build_model(self):
	        
	        # 创建生成器,内容编码长度为 dim_neck ,说话人编码长度为 dim_emb ,解码长度为 dim_pre ,采样系数为 freq
	        # 创建成员变量 G ,应用上述组装的生成器
	        self.G = Generator(self.dim_neck, self.dim_emb, self.dim_pre, self.freq)        
	        
	        # 创建 Adam 优化器,设置学习速率为 0.0001 ,G.parameters() 为可迭代的参数优化
	        # 创建成员变量 g_optimizer ,应用上述组装的 Adam 优化器
参数组
	        self.g_optimizer = torch.optim.Adam(self.G.parameters(), 0.0001)
	        
	        # 将网络模型移至实现设置的设备中
	        self.G.to(self.device)
      reset_grad

          该函数的作用是: 重置梯度缓冲器

          输入参数:

          输出参数:

          代码详解:

	    def reset_grad(self):
	        """ 重置梯度缓冲器 """
	        self.g_optimizer.zero_grad()
      train

          该函数的作用是: 训练网络模型

          输入参数:

          输出参数:

          代码详解:

	    def train(self):
	        # 设置数据迭代器
	        data_loader = self.vcc_loader
	        
	        # 按指定顺序打印日志
	        keys = ['G/loss_id','G/loss_id_psnt','G/loss_cd']
	            
	        # 开始训练
	        print('Start training...')
	        # 训练开始时间
	        start_time = time.time()
	        # 训练 num_iters 次
	        for i in range(self.num_iters):
	
	            # =================================================================================== #
	            #                             1. 预处理输入的数据                                #
	            # =================================================================================== #
	
	            # 取数据
	            try:
	                # x_real 为长度为截取长度的梅尔频谱数据,emb_org 为说话人编码
	                x_real, emb_org = next(data_iter)
	            except:
	                # 若失败,则设置数据迭代器,再次尝试取数据
	                data_iter = iter(data_loader)
	                x_real, emb_org = next(data_iter)
	            
	            
	            # 将长度为截取长度的梅尔频谱数据与说话人编码移入设定的设备内
	            x_real = x_real.to(self.device) 
	            emb_org = emb_org.to(self.device) 
	                        
	       
	            # =================================================================================== #
	            #                               2. 训练生成器                                #
	            # =================================================================================== #
	            
	            # 启用 Batch Normalization 和 Dropout
	            self.G = self.G.train()
	                        
	            # 前向传播,更新参数,标识映射损失
	            # 输入音频数据为 x_real ,源说话人与目标说话人都为 emb_org
	            # 得到初步转换结果 x_identic ,最终转换结果 x_identic_psnt ,内容编码 code_real
	            x_identic, x_identic_psnt, code_real = self.G(x_real, emb_org, emb_org)
	            # 计算初步转换结果的误差 g_loss_id (均方损失函数)
	            g_loss_id = F.mse_loss(x_real, x_identic)   
	            # 计算最终转换结果的误差 g_loss_id_psnt
	            g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)   
	            
	            # 计算代码语义 code_reconst 
	            code_reconst = self.G(x_identic_psnt, emb_org, None)
	            # 计算代码语音损失 g_loss_cd (L1 损失)
	            g_loss_cd = F.l1_loss(code_real, code_reconst)
	
	
	            # 向后传播和优化
	            # 计算综合损失值
	            g_loss = g_loss_id + g_loss_id_psnt + self.lambda_cd * g_loss_cd
	            # 重置梯度缓冲器
	            self.reset_grad()
	            # 后向传播
	            g_loss.backward()
	            # 执行优化
	            self.g_optimizer.step()
	
	            # 记录日志
	            loss = {}
	            loss['G/loss_id'] = g_loss_id.item()
	            loss['G/loss_id_psnt'] = g_loss_id_psnt.item()
	            loss['G/loss_cd'] = g_loss_cd.item()
	
	            # =================================================================================== #
	            #                                 4. 杂项                                   #
	            # =================================================================================== #
	
	            # 每隔 log_step 次训练,打印训练信息
	            if (i+1) % self.log_step == 0:
	                # 计算当前训练消耗时间
	                et = time.time() - start_time
	                # 保留整数秒
	                et = str(datetime.timedelta(seconds=et))[:-7]
	                # 打印消耗时间与训练进度
	                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
	                # 打印当前训练时刻的损失日志
	                for tag in keys:
	                    log += ", {}: {:.4f}".format(tag, loss[tag])
	                print(log)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值