ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2)
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths)
loss.backward()
切記 loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths),其中模型輸出的log_probs一定要放在cpu上,如果放在cuda上,那麼loss訓練過程中會下降的特別慢甚至不下降。
CTC Loss详解与应用
本文探讨了CTC Loss在深度学习中的应用,特别是在语音识别和序列预测任务中。通过实例演示了如何使用PyTorch实现CTC Loss,并强调了模型输出在训练过程中的正确设备配置,以确保损失值的有效下降。

被折叠的 条评论
为什么被折叠?



