咱们直接进入正题!
def train(
    model,
    loss1,
    loss2,
    train_dataloader,
    optimizer_loss1,
    optimizer_loss2,
    epoch,
    writer,
    device_num,
    weight_loss2=0.005,
):
    """Run one pass over ``train_dataloader`` with two losses and two optimizers.

    For every batch the combined objective
    ``loss1(log_softmax(output[1]), target) + weight_loss2 * loss2(output[0], target)``
    is back-propagated once. ``optimizer_loss1`` is then stepped, the gradients
    of ``loss2.parameters()`` are divided by ``weight_loss2``, and
    ``optimizer_loss2`` is stepped.

    Args:
        model: Callable module; ``model(data)`` must be indexable, with
            ``output[0]`` fed to ``loss2`` and ``output[1]`` fed (after
            ``log_softmax`` over dim 1) to ``loss1``. Presumably these are the
            feature embedding and the class logits -- confirm against the model.
        loss1: Criterion called as ``loss1(log_probs, target)``.
        loss2: Criterion called as ``loss2(output[0], target)``; must expose
            ``parameters()`` (it carries its own trainable parameters).
        train_dataloader: Iterable yielding ``(data, target)`` pairs.
        optimizer_loss1: Optimizer stepped right after ``backward()``.
        optimizer_loss2: Optimizer stepped after the ``loss2`` gradients have
            been rescaled.
        epoch: Not used in this function body; kept for interface compatibility.
        writer: Not used in this function body; kept for interface compatibility.
        device_num: CUDA device index; batches are moved to ``cuda:<device_num>``
            only when CUDA is available.
        weight_loss2: Weight of the second loss term in the combined objective.
            Must be non-zero because the ``loss2`` gradients are divided by it.

    Raises:
        ValueError: If ``weight_loss2`` is zero.
    """
    if weight_loss2 == 0:
        raise ValueError("weight_loss2 must be non-zero")

    model.train()
    device = torch.device("cuda:" + str(device_num))
    use_cuda = torch.cuda.is_available()
    # Undo the loss weighting on loss2's own parameters so that weight_loss2
    # does not scale the step taken by optimizer_loss2.
    grad_rescale = 1. / weight_loss2

    for data, target in train_dataloader:
        target = target.long()
        if use_cuda:
            data = data.to(device)
            target = target.to(device)

        optimizer_loss1.zero_grad()
        optimizer_loss2.zero_grad()

        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)  # second loss term
        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
        result_loss_batch.backward()

        optimizer_loss1.step()

        for param in loss2.parameters():
            # A parameter that did not take part in this batch has no gradient.
            if param.grad is not None:
                param.grad.mul_(grad_rescale)
        optimizer_loss2.step()
我这里采用了两项损失:loss1 用于优化网络权重,loss2 用于优化中心矢量。网络权重和中心矢量都是可训练的参数(而非超参数),因此需要两个优化器。如果多个损失项都只用于优化网络权重,那么只采用一个优化器即可,如下所示:
def train(
    model,
    loss1,
    loss2,
    train_dataloader,
    optimizer,
    epoch,
    writer,
    device_num,
    weight_loss2=0.005,
):
    """Run one pass over ``train_dataloader`` with two losses and one optimizer.

    For every batch the combined objective
    ``loss1(log_softmax(output[1]), target) + weight_loss2 * loss2(output[0], target)``
    is back-propagated and ``optimizer`` is stepped once.

    Args:
        model: Callable module; ``model(data)`` must be indexable, with
            ``output[0]`` fed to ``loss2`` and ``output[1]`` fed (after
            ``log_softmax`` over dim 1) to ``loss1``. Presumably these are the
            feature embedding and the class logits -- confirm against the model.
        loss1: Criterion called as ``loss1(log_probs, target)``.
        loss2: Criterion called as ``loss2(output[0], target)``.
        train_dataloader: Iterable yielding ``(data, target)`` pairs.
        optimizer: Optimizer zeroed before and stepped after each backward pass.
        epoch: Not used in this function body; kept for interface compatibility.
        writer: Not used in this function body; kept for interface compatibility.
        device_num: CUDA device index; batches are moved to ``cuda:<device_num>``
            only when CUDA is available.
        weight_loss2: Weight of the second loss term in the combined objective.
    """
    model.train()
    device = torch.device("cuda:" + str(device_num))
    use_cuda = torch.cuda.is_available()

    for data, target in train_dataloader:
        target = target.long()
        if use_cuda:
            data = data.to(device)
            target = target.to(device)

        optimizer.zero_grad()

        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)  # second loss term
        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
        result_loss_batch.backward()

        optimizer.step()
详细代码,请翻阅我们的论文,代码已开源,开源链接可查论文摘要。
若该经验贴对您科研、学习有所帮助,欢迎您引用我们的论文。
[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June 2023, doi: 10.1109/JIOT.2023.3240242.
[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.