常用的loss函数,以及在训练中的使用

KL 散度

算KL散度的时候要注意前后顺序以及加log

import torhch.nn as nn
d_loss = nn.KLDivLoss(reduction=reduction_kd)(F.log_softmax(y / T, dim=1),
                                     F.softmax(teacher_scores / T, dim=1)) * T * T

蒸馏loss T 为

在这里插入图片描述
在这里插入图片描述

L2 loss

import torch.nn.functional as F
F.mse_loss(teacher_patience.float(), student_patience.float()).half()

做标准化处理

if normalized_patience:
        teacher_patience = F.normalize(teacher_patience, p=2, dim=2)
        student_patience = F.normalize(student_patience, p=2, dim=2)

L2 范数
在这里插入图片描述

CEloss

分类问题

nll_loss = F.cross_entropy(y, labels, reduction=reduction_nll)

在这里插入图片描述

CTCLoss

计算连续(未分割)时间序列和目标序列之间的损失.
torch.nn.CTCLoss(blank=0, reduction=‘mean’, zero_infinity=False)

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)

图片卷积-》序列做loss
31,1,64,256 -》 31,512,3,65-》31,65,512,3-》polling-》31,65,512》》liner》31,65,103》TNC》65,31,103
在这里插入图片描述

AdaptiveAvgPool2d

自适应池化,最后设置要输出的H和W,B和N(C)不变
在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值