2024数字媒体技术大三深度学习案例三训练代码研究

代码一段时间不看就烂了,又忘记是啥意思,受不了了来写个研究,方便之后看。

超参数等基础设置在后面训练中遇到了再讲。所有在后面程序会讲到的都会用红色标注,然后用蓝色来回应~~~。块引用(也就是这三行字所处的灰框框)代码表示是训练代码之前的准备部分,或是补充。我所说的“默认”代表老师的设置。

# ----------
#  Training
# ----------

直接从Training开始

准备工作

generator.train()
discriminator.train()
prev_time = time.time()

 开启生成器和判别器的训练。prev_time一眼没啥用,实验要求也和它没啥关系,我甚至想把它删了(但估计每次pull下来的程序里又有它捏)

def lr_scheduler(optimizer, init_lr, epoch, lr_decay_iter):
    if epoch % lr_decay_iter:
        return init_lr
    lr = init_lr * 0.5
    optimizer.param_groups[0]['lr'] = lr
    return lr

学习率衰减函数,按照此代码逻辑是每lr_decay_iter个轮次让学习率变为原来的一半。并在此按照另一个深度学习文章的学习率设置,要求最后的学习率为开始的1/100甚至更小。https://blog.csdn.net/JNingWei/article/details/79243800

下面马上就有它出现啦。

min_tloss = 500
tloss_res = {}

这俩玩意后面再讲。

接下来进行epoch次数训练。次数由opt.n_epochs - opt.epoch得到,按照初始超参数的设置为200 - 0 = 200次

    ch_lr_avg_loss_depart = []
    ch_lr_avg_loss = 0

没错还是放后面,不过很快就到你了!

    if epoch > 0:
        learning_rate_G = lr_scheduler(optimizer_G, learning_rate_G,epoch+1, opt.lrgd)
        # learning_rate_D = lr_scheduler(optimizer_D, learning_rate_D,epoch+1, opt.lrdd)

按照老师的说法,本次实验不需要对判别器进行优化。而优化的函数就是上面提到的lr_scheduler,其中除epoch的参数如下:

optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate_G, betas=(opt.b1, opt.b2))

生成器优化器如上,其lr是本次实验需要调整的超参数之一。

learning_rate_G = opt.learning_rate_G

生成器学习率的初始值设为超参数,默认为1e-4。

而opt.lrgd点明了其为超参数,默认值为90。

那么回过头去看lr_scheduler,可以发现每90轮才会离开函数中的if (如8 % 5会判断为真),也就是说每90轮才衰减一次。这和链接所说的1/100可相差甚远。

    optimizer.param_groups[0]['lr'] = lr

根据上面optinizer_G的定义中有'lr',我猜想这个是把lr的值传给优化器的参数组,以更新学习率的效果。没进if时,优化器的参数不需要变化,所以if中不需要含这行代码。

    for i, batch in enumerate(dataloader):

enumerate(dataloader) 是一个 Python 内置函数,用于添加计数器到迭代对象 dataloader,它返回 (index, value) 对。按次定义,这里的 i 是批次的索引或编号,它从 0 开始并在每次迭代时加一,而 batch 是数据加载器 dataloader 的下一个元素,包含了一批训练数据。

因此我觉得此处的i更像保存结果中的Batch

        # Model inputs
        real_A = batch['B'].type(Tensor)
        real_B = batch['A'].type(Tensor)

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)

单看这个一头雾水,real_A怎么和B有关?下面又反过来了?

生成器

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)
        # ipdb.set_trace()
        pred_fake = discriminator(fake_B, real_A)
        # ipdb.set_trace()
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        ch_lr_avg_loss_depart.append(loss_G.data.item())
        # 反向传播是通过链式法则自动计算模型所有参数的损失函数的梯度的过程,这些梯度指示了为了减少loss,每个参数应该如何变化
        loss_G.backward()
        # 根据上面的梯度来更新模型的参数
        optimizer_G.step()

结合生成器代码并参考Pix2Pix的原理,我们可知:

real_A:源图像,即下图imgA

real_B:目标图像,本实验中也称风格图像

fake_B:生成图像,即下图imgB

valid 和 fake 是用来为真实样本和生成样本提供目标标签的,真实样本应该被标记为“真”(1),生成的样本应该被标记为“假”(0)。这些张量通常在训练GAN时是固定的,并不参与梯度下降。

 接下来解析一下生成器代码

optimizer_G.zero_grad()

表示将生成器的所有参数的梯度值清零,避免在参数更新时产生错误的累积效果,保证每个批次的训练梯度仅与该批次的数据有关。

        # GAN loss
        fake_B = generator(real_A)
        # ipdb.set_trace()
        pred_fake = discriminator(fake_B, real_A)
        # ipdb.set_trace()
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

pred_fake 是判别器对于生成的假数据 fake_B 的输出,同时 valid 是一个全1的张量,其形状与 pred_fake 一致,表示真实样本标签。这一行将判别器的输出(它对假样本的预测)和代表真实样本的目标 valid 进行比较,计算生成器的对抗性损失 loss_GAN。


这段代码中的生成器的目标是试图欺骗判别器,使其把生成的假数据 fake_B 判断为真实数据。因此,生成器的损失 loss_GAN 计算了判别器正确识别出 fake_B 不是真实样本而产生的误差。通过最小化这个损失,生成器将学习如何生成更加真实的数据以迷惑判别器。

同理,loss_pixel代表着生成图像和目标风格图像的差距。

ch_lr_avg_loss_depart有点意思了,这里结合下文代码做一个缩略:

# ...

for epoch in range(opt.epoch, opt.n_epochs):
    
    # 每个 epoch 开始,初始化损失列表
    ch_lr_avg_loss_depart = []
    
    # 其他代码 ...
    
    for i, batch in enumerate(dataloader):
        
        #...
        
        # 计算生成器损失并更新权重
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        
        # 将当前损失加入列表
        ch_lr_avg_loss_depart.append(loss_G.data.item())
    
    # 通过损失列表计算平均损失
    ch_lr_avg_loss = sum(ch_lr_avg_loss_depart) / len(ch_lr_avg_loss_depart)
    
    # 其他代码 ...

ch_lr_avg_loss_depart 用于存储每个 batch 的生成器损失。在完成一个 epoch 的所有 batch 处理后,这个列表中包含了该 epoch 中所有 batch 的生成器损失。然后,代码通过计算这个列表所有数值的平均值(sum(ch_lr_avg_loss_depart) / len(ch_lr_avg_loss_depart)),来得到整个 epoch 的平均生成器损失。

生成器代码中含注释的就不讲啦,也讲不明白。loss_G就是那俩玩意相加,可能有优化空间?

判别器

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        # 使用detach来防止在判别器训练过程中计算关于生成器参数的梯度
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()

看着和生成器的逻辑一样,就不赘述咧。

数据保存

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        print("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [ pixel: %f, loss_GAN: %f] ETA: %s" %
              (epoch, opt.n_epochs,
               i, len(dataloader),
               loss_D.item(),
               lambda_pixel * loss_pixel.item(), loss_GAN.item(),
               time_left))

        with io.open(pix_path + '/loss_message/%s/train_loss.txt' % opt.dataset_name, 'a', encoding='utf-8') as file:
            file.write(
                '[Epoch: {}] [Dloss: {:.4f}] [loss_pixel: {:.4f}] [loss_GAN: {:.4f}] [loss_real: {:.4f}] [loss_fake: {:.4f}] [Batch: {}/{}] \n'
                .format(epoch, loss_D.item(), lambda_pixel * loss_pixel.item(), loss_GAN.item(), loss_real.item(),
                        loss_fake.item(), i, len(dataloader)))

        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # 计算平均loss和时间
    # check learning rate average loss 检查学习率对应的平均损失
    ch_lr_avg_loss = sum(ch_lr_avg_loss_depart) / len(ch_lr_avg_loss_depart)

    print('----------------------------------------------------------- \n')
    print('avg_loss: {:.4f} \n'.format(ch_lr_avg_loss))

    with io.open(pix_path + '/loss_message/%s/loss_time.txt' % opt.dataset_name, 'a', encoding='utf-8') as file:
        file.write('[avg_loss: {:.4f}] \n'.format(ch_lr_avg_loss))

    avg_loss = 0
    avg_loss = loss_val()
    tloss_res[epoch] = avg_loss

    # 每50轮保存模型参数,save函数中使用.state_dict()是只保存参数而不是整个模型哦
    if epoch > 0 and (epoch + 1) % 50 == 0:
        print()
        torch.save(generator.state_dict(), pix_path + '/saved_models/%s/generator_%d.pth' % (opt.dataset_name, epoch))
        torch.save(discriminator.state_dict(),
                   pix_path + '/saved_models/%s/d/discriminator_%d.pth' % (opt.dataset_name, epoch))
    # 保存loss最小时的模型参数
    if tloss_res[epoch] < min_tloss:
        min_tloss = tloss_res[epoch]
        tloss_res['min'] = tloss_res[epoch]
        tloss_res['minepoch'] = epoch
        torch.save(generator.state_dict(), pix_path + '/saved_models/%s/generator_min.pth' % (opt.dataset_name))
        torch.save(discriminator.state_dict(),
                   pix_path + '/saved_models/%s/d/discriminator_min.pth' % (opt.dataset_name))

with io.open(pix_path + '/loss_message/%s/list_loss.txt' % opt.dataset_name, 'a', encoding='utf-8') as file:
    file.write('tloss_res: {} \n'.format(tloss_res))

这部分主要是文件和数据的保存,感觉没啥可讲的awa

前面的红字懒得管了,寄awa

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值