【debug】pytorch CTC_Loss为nan

1. feature中有nan值

有次max_pool2d参数设计错误出现了这种情况
可以通过 print(feature.max()) 看feature的最大值

2. target length有0值

现在pytorch中有自带的ctcloss其用法


>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()

其中:

  • target.shape = target_lengths.sum()
    注意: 这里的target_lengths中如果有 ‘0’ 则loss为nan。表示一张图片中没有字符,一个字符都没有
    解决方案: 在dataset中过滤掉len(label)==0的图片
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中的CTCLoss是指Connectionist Temporal Classification Loss,它是一种用于解决神经网络标签和输出不对齐问题的方法。CTCLoss的优点是不需要强制对齐标签且标签可以是可变长度的。它主要应用于场景文本识别、语音识别和手写字识别等工程场景。在PyTorch 1.0.x版本内,已经内置了CTCLoss接口,可以直接使用。下面是一个使用CTCLoss的代码示例: ```python import torch import torch.nn as nn ctc_loss = nn.CTCLoss() log_probs = torch.randn(50, 16, 20).log_softmax(2).requires_grad_() targets = torch.randint(1, 20, (16, 30), dtype=torch.long) input_lengths = torch.full((16,), 50, dtype=torch.long) target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) loss.backward() ``` 在这个示例中,我们首先创建了一个CTCLoss实例,然后生成了一些随机的log probabilities作为网络的输出。接着,我们生成了一些随机的目标标签和输入长度以及目标长度。最后,我们使用CTCLoss计算了损失,并进行了反向传播。\[2\] 在创建CTCLoss实例时,可以通过设置参数来自定义一些属性。例如,可以使用`blank`参数来指定空白符的序号,`reduction`参数来指定损失的计算方式。\[3\] 希望这个回答对你有帮助! #### 引用[.reference_title] - *1* [如何使用pytorch内置torch.nn.CTCLoss的方法&&车牌识别应用](https://blog.csdn.net/CSDNwei/article/details/120223026)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [pytorch的torch.nn.CTCLoss方法](https://blog.csdn.net/benben044/article/details/125130411)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Pytorch中的CTC loss](https://blog.csdn.net/fidbdiej/article/details/124587812)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值