warpctc

  1. warpctc_pytorch
import torch
import warpctc_pytorch as warp_ctc
from torch.autograd import Function
from torch.nn import Module

from ._warp_ctc import *

__version__ = '0.1.1'


def _assert_no_grad(tensor):
    """Reject a tensor that is marked as requiring gradients.

    Args:
        tensor: any object exposing a boolean ``requires_grad`` attribute.

    Raises:
        AssertionError: if ``tensor.requires_grad`` is truthy.
    """
    # Raised explicitly instead of using an `assert` statement so the check
    # is still performed when Python runs with -O (which strips asserts).
    # The exception type and message are the same as before.
    if tensor.requires_grad:
        raise AssertionError(
            "gradients only computed for acts - please "
            "mark other tensors as not requiring gradients")


class _CTC(Function):
    """Autograd function that delegates the CTC computation to warp-ctc.

    ``forward`` calls ``warp_ctc.gpu_ctc`` or ``warp_ctc.cpu_ctc`` and keeps
    the resulting gradient buffer on ``ctx.grads``; ``backward`` only scales
    that stored buffer by the incoming gradient.
    """

    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, size_average=False,
                length_average=False, blank=0, reduce=True):
        """Compute the CTC cost(s) and stash the gradient w.r.t. ``acts``.

        Args:
            ctx: autograd context; receives the gradient buffer as ``grads``.
            acts: activations; dimension 1 is used as the minibatch size.
            labels, act_lens, label_lens: forwarded to the warp-ctc kernel.
            size_average: divide cost and gradient by the minibatch size
                (only when ``reduce`` is true and ``length_average`` is false).
            length_average: divide cost and gradient by ``sum(act_lens)``
                (only when ``reduce`` is true); takes priority over
                ``size_average``.
            blank: forwarded to the warp-ctc kernel.
            reduce: when true, return the summed cost as a 1-element tensor;
                otherwise return the per-sample costs with shape (B, 1).

        Returns:
            CPU tensor holding the cost(s).
        """
        acts = acts.contiguous()
        if acts.is_cuda:
            ctc_kernel = warp_ctc.gpu_ctc
        else:
            ctc_kernel = warp_ctc.cpu_ctc

        batch = acts.size(1)
        grads = torch.zeros(acts.size()).type_as(acts)
        costs = torch.zeros(batch).cpu()
        # The kernel's return value is unused; presumably it fills `grads`
        # and `costs` in place (confirm against the warp-ctc binding).
        # Note the argument order: label_lens comes before act_lens.
        ctc_kernel(acts, grads, labels, label_lens, act_lens,
                   batch, costs, blank)

        if not reduce:
            # Keep one cost per sample as B x 1 so that the B x 1
            # `grad_output` seen in backward() broadcasts against `grads`.
            ctx.grads = grads
            return costs.unsqueeze(1)

        costs = torch.FloatTensor([costs.sum()])

        # Pick the normaliser; length averaging wins over size averaging.
        if length_average:
            # Average over every frame of every sample in the batch.
            divisor = torch.sum(act_lens).item()
        elif size_average:
            # Average over the samples in the batch.
            divisor = batch
        else:
            divisor = None

        if divisor is not None:
            grads = grads / divisor
            costs = costs / divisor

        ctx.grads = grads
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the stored gradient by ``grad_output``.

        Returns an 8-tuple matching the inputs of ``forward``: the gradient
        for ``acts`` followed by ``None`` for the seven other arguments.
        """
        # The cost is produced on the CPU, so the incoming gradient is moved
        # to wherever the stored gradient buffer lives before multiplying.
        upstream = grad_output.to(ctx.grads.device)
        acts_grad = ctx.grads * upstream
        return (acts_grad,) + (None,) * 7

class CTCLoss(Module):
    """
    CTC loss module; validates its inputs and forwards them, together with
    the options stored at construction time, to the `_CTC` autograd function.

    Parameters:
        blank (int): value forwarded as the `blank` argument of the CTC
            function; presumably the index of the blank label
            (default: `0`)
        size_average (bool): normalize the loss by the batch size
            (default: `False`)
        length_average (bool): normalize the loss by the total number of frames
            in the batch. If `True`, supersedes `size_average`
            (default: `False`)
        reduce (bool): sum (or average, see the options above) the loss over
            the minibatch. If `False`, returns a loss per batch element
            instead and ignores the averaging options.
            (default: `True`)
    """
    def __init__(self, blank=0, size_average=False, length_average=False, reduce=True):
        super(CTCLoss, self).__init__()
        self.ctc = _CTC.apply
        self.blank = blank
        self.size_average = size_average
        self.length_average = length_average
        self.reduce = reduce

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example

        Raises AssertionError if `labels` is not 1 dimensional or if any of
        `labels`, `act_lens`, `label_lens` requires gradients.
        """
        # Raised explicitly instead of using an `assert` statement so the
        # check is still performed under `python -O`; the exception type is
        # unchanged, and it now carries a message.
        if labels.dim() != 1:
            raise AssertionError(
                "labels must be 1 dimensional, got {} dimensions".format(labels.dim()))
        _assert_no_grad(labels)
        _assert_no_grad(act_lens)
        _assert_no_grad(label_lens)
        return self.ctc(acts, labels, act_lens, label_lens, self.size_average,
                        self.length_average, self.blank, self.reduce)


# Example arguments for CTCLoss.forward with a batch of 32 sequences:
#
#   acts:       shape torch.Size([1183, 32, 214])
#   labels:     tensor([103, 182, 162,  ..., 174, 106,  60], dtype=torch.int32), length 2124
#   act_lens:   tensor([1183, 1148, 1092, 1055, 1017, 1017, 1014, 1009,  984,  943,  942,  942,
#                        936,  935,  932,  911,  910,  892,  872,  871,  865,  845,  843,  843,
#                        826,  826,  824,  792,  777,  755,  736,  731], dtype=torch.int32), length 32
#   label_lens: tensor([72, 80, 88, 86, 72, 66, 74, 64, 76, 58, 70, 76, 64, 60, 68, 52, 68, 64,
#                       74, 50, 68, 62, 74, 62, 64, 64, 62, 54, 62, 56, 62, 52],
#                      dtype=torch.int32), length 32; the entries sum to 2124 == len(labels)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值