文本分类半监督学习(七)

2021SC@SDUSC

损失函数的计算

#训练损失函数,用在训练时
train_criterion = SemiLoss()
#交叉熵损失, 用在验证集和测试集, 是模型训完完成后的,使用交叉熵进行计算损失
criterion = nn.CrossEntropyLoss()

#使用SemiLoss计算损失
class SemiLoss(object):
def __call__(self, outputs_x, targets_x, outputs_u, targets_u, outputs_u_2, epoch, mixed=1):
"""
半监督损失函数
:param outputs_x: 模型输出的x
:param targets_x: 真实的x
:param outputs_u: 模型输出的无标签的x
:param targets_u: 真实的无标签的x
:param outputs_u_2: 模型输出的无标签x_2
:param epoch: 迭代次数
:param mixed: 是否是混合过的输出
:return:
"""
if args.mix_method == 0 or args.mix_method == 1:
#有监督的x的损失
Lx = - \
torch.mean(torch.sum(F.log_softmax(
outputs_x, dim=1) * targets_x, dim=1))
#无监督的x输出的概率值
probs_u = torch.softmax(outputs_u, dim=1)
#论文中公式显示的kl散度, batch mean 批次均值作为统计计算KL散度
Lu = F.kl_div(probs_u.log(), targets_u, None, None, 'batchmean')
#计算hinge Loss 折页损失 max(0,1-(wTxi +b)yi)
Lu2 = torch.mean(torch.clamp(torch.sum(-F.softmax(outputs_u, dim=1)
* F.log_softmax(outputs_u, dim=1), dim=1) - args.margin, min=0))

elif args.mix_method == 2:
if mixed == 0:
Lx = - \
torch.mean(torch.sum(F.logsigmoid(
outputs_x) * targets_x, dim=1))

probs_u = torch.softmax(outputs_u, dim=1)

Lu = F.kl_div(probs_u.log(), targets_u,
None, None, 'batchmean')

Lu2 = torch.mean(torch.clamp(args.margin - torch.sum(
F.softmax(outputs_u_2, dim=1) * F.softmax(outputs_u_2, dim=1), dim=1), min=0))
else:
Lx = - \
torch.mean(torch.sum(F.log_softmax(
outputs_x, dim=1) * targets_x, dim=1))

probs_u = torch.softmax(outputs_u, dim=1)
Lu = F.kl_div(probs_u.log(), targets_u,
None, None, 'batchmean')

Lu2 = torch.mean(torch.clamp(args.margin - torch.sum(
F.softmax(outputs_u, dim=1) * F.softmax(outputs_u, dim=1), dim=1), min=0))

return Lx, Lu, args.lambda_u * linear_rampup(epoch), Lu2, args.lambda_u_hinge * linear_rampup(epoch)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值