pytorch CTCLOSS 降不下來的bug

CTC Loss详解与应用
本文探讨了CTC Loss在深度学习中的应用,特别是在语音识别和序列预测任务中。通过实例演示了如何使用PyTorch实现CTC Loss,并强调了模型输出在训练过程中的正确设备配置,以确保损失值的有效下降。
ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2)
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths)
loss.backward()

切記 loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths),其中模型輸出的log_probs一定要放在cpu上,如果放在cuda上,那麼loss訓練過程中會下降的特別慢甚至不下降。

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值