How to Train a Deep Neural Network with Multiple Loss Terms in PyTorch [Continuously Updated]

Let's get straight to the point!

import torch
import torch.nn.functional as F


def train(model, loss1, loss2, train_dataloader, optimizer_loss1, optimizer_loss2, epoch, writer, device_num):
    model.train()
    device = torch.device("cuda:" + str(device_num))
    correct = 0       # accumulators for logging; the logging code is omitted from this excerpt
    value_loss1 = 0
    value_loss2 = 0
    result_loss = 0
    for data_nnl in train_dataloader:
        data, target = data_nnl
        target = target.long()
        if torch.cuda.is_available():
            data = data.to(device)
            target = target.to(device)

        optimizer_loss1.zero_grad()
        optimizer_loss2.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term (classification)
        value_loss2_batch = loss2(output[0], target)          # second loss term (center loss)

        weight_loss2 = 0.005  # weighting coefficient of the second loss term

        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch

        result_loss_batch.backward()
        optimizer_loss1.step()
        # backward() scaled the center-vector gradients by weight_loss2;
        # divide it back out so the centers see the unweighted gradient of loss2.
        for param in loss2.parameters():
            param.grad.data *= (1. / weight_loss2)
        optimizer_loss2.step()

I use two loss terms here: loss1 optimizes the network weights, while loss2 optimizes the center vectors, which are themselves trainable parameters, so two optimizers are needed. A note on the for-loop over loss2.parameters() at the end of train(): because loss2 enters the total loss scaled by weight_loss2, backward() also scales the gradients of the center vectors by weight_loss2; dividing that factor back out updates the centers with the unweighted gradient of loss2, so their effective learning rate is independent of the loss weight.
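Below is a minimal sketch of how this two-optimizer setup can be wired up. The CenterLoss and TinyNet classes and all hyperparameter values are illustrative assumptions of mine, not the exact implementation from our papers.

import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    # Illustrative center loss: one trainable center vector per class.
    def __init__(self, num_classes, feature_dim):
        super().__init__()
        # Trainable center vectors; these are the parameters optimizer_loss2 updates.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, features, target):
        # Mean squared distance between each sample's feature and its class center.
        return ((features - self.centers[target]) ** 2).sum(dim=1).mean()


class TinyNet(nn.Module):
    # Illustrative network returning (features, logits), matching output[0] / output[1] in train().
    def __init__(self, in_dim=128, feature_dim=64, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, feature_dim), nn.ReLU())
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return features, self.head(features)


model = TinyNet()
loss1 = nn.NLLLoss()  # pairs with the F.log_softmax in train()
loss2 = CenterLoss(num_classes=10, feature_dim=64)

# One optimizer per parameter group: network weights vs. center vectors;
# these are passed to train() above as optimizer_loss1 / optimizer_loss2.
optimizer_loss1 = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_loss2 = torch.optim.SGD(loss2.parameters(), lr=0.5)

If all of the loss terms are used only to optimize the network weights, then a single optimizer suffices, as shown below.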

def train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num):
    model.train()
    device = torch.device("cuda:" + str(device_num))
    correct = 0
    value_loss1 = 0
    value_loss2 = 0
    result_loss = 0
    for data_nnl in train_dataloader:
        data, target = data_nnl
        target = target.long()
        if torch.cuda.is_available():
            data = data.to(device)
            target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)          # second loss term

        weight_loss2 = 0.005  # weighting coefficient of the second loss term

        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch

        result_loss_batch.backward()
        optimizer.step()  # a single optimizer updates all network weights
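To round this out, here is a hedged usage sketch of the single-optimizer variant, reusing TinyNet from the sketch above; FeatureNormLoss, the dummy data, and all settings are illustrative assumptions rather than the setup from our papers. Since neither loss term has trainable parameters of its own, one optimizer over model.parameters() covers everything.

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter


class FeatureNormLoss(nn.Module):
    # Illustrative parameter-free second loss: penalizes large feature norms.
    def forward(self, features, target):
        return features.pow(2).sum(dim=1).mean()


# Dummy data so the sketch runs end to end; replace with a real dataset.
dummy_x = torch.randn(256, 128)
dummy_y = torch.randint(0, 10, (256,))
train_dataloader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=32, shuffle=True)

device_num = 0
model = TinyNet()
if torch.cuda.is_available():
    model = model.to(torch.device("cuda:" + str(device_num)))

loss1 = nn.NLLLoss()
loss2 = FeatureNormLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # single optimizer over the network weights

writer = SummaryWriter()  # assuming a TensorBoard writer, matching the writer argument of train()
for epoch in range(10):
    train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num)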

For the complete code, please refer to our papers; the code is open source, and the link can be found in the paper abstracts.

If this post has been helpful for your research or study, you are welcome to cite our papers.

[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June 2023, doi: 10.1109/JIOT.2023.3240242.

[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.
