tf.nn.ctc

该博客探讨了在深度学习中,CTC(Connectionist Temporal Classification)损失函数如何用于解决不固定长度序列的匹配问题,并展示了在注意力机制下如何构建对齐矩阵和CTC损失。通过实例,解释了CTC损失的计算以及如何使用CTC损失进行解码,同时对比了TensorFlow和PyTorch中CTC损失的实现方式。
摘要由CSDN通过智能技术生成

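Example: converting a dense, zero-padded label matrix into the SparseTensor that tf.nn.ctc_loss expects, with batch_size=4, max_time=5, and labels_length=[5,4,3,2].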

import tensorflow as tf
# Dense label matrix, right-padded with 0 (row lengths are 5, 4, 3, 2).
# dtype is passed by keyword: the second positional parameter of tf.Variable
# is `trainable`, not the dtype.
labels = tf.Variable([[4, 3, 1, 2, 5],
                      [2, 3, 4, 1, 0],
                      [1, 2, 3, 0, 0],
                      [5, 4, 0, 0, 0]], dtype=tf.int32)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Coordinates of every non-padding (non-zero) entry.
  idx = tf.where(tf.not_equal(labels, 0))
  # Sparse representation: (indices, values, dense_shape).
  sparse = tf.SparseTensor(idx, tf.gather_nd(labels, idx), labels.get_shape())
  s = sess.run(sparse)
  # print() with a single argument behaves the same on Python 2 and Python 3;
  # the bare `print x` statement is a syntax error on Python 3.
  print(s.indices)
  print(s.values)
  print(s.dense_shape)

[[0 0]
[0 1]
[0 2]
[0 3]
[0 4]
[1 0]
[1 1]
[1 2]
[1 3]
[2 0]
[2 1]
[2 2]
[3 0]
[3 1]]
[4 3 1 2 5 2 3 4 1 1 2 3 5 4]
[4 5]

  1. ctc例子
https://blog.csdn.net/JackyTintin/article/details/79425866
  1. search
# tf.reset_default_graph()
import tensorflow as tf
import numpy as np
# Fix the NumPy seed so the two random draws below are reproducible.
np.random.seed(1111)
# Batch size shared by the random inputs below.
B = 4
# x is passed as `text_embeddings` to get_aligment_matrix: shape [B, 5, 16].
x = tf.constant(np.random.normal(loc=0.1, scale=0.3, size=[B,5,16]), dtype=tf.float32)
# y is passed as `mel_embeddings` to get_aligment_matrix: shape [B, 10, 16].
y = tf.constant(np.random.normal(loc=0.1, scale=0.3, size=[B,10,16]), dtype=tf.float32)
# Per-sample valid lengths along axis 1 of x and y respectively.
x_len = tf.constant([3, 2, 5, 4])
y_len = tf.constant([6, 5, 10,8])
# Alternative single-sample lengths, kept for experimentation:
# x_len = tf.constant([5,])
# y_len = tf.constant([10,])
#based on rad-tts
def get_aligment_matrix(text_embeddings, text_lengths, mel_embeddings, mel_lengths):
    """Build the frame-to-token score matrix used as CTC emissions.

    The score of a (mel frame, text token) pair is ``exp(-d)`` where ``d`` is
    the squared L2 distance between the two embeddings; pairs that fall in
    the padding of either sequence are zeroed. One extra class column is
    appended, and for every sample a constant 0.001 is added at column
    ``text_lengths[b]`` (the column callers treat as the CTC blank once they
    crop the class axis to ``text_lengths[b] + 1``).

    Args:
        text_embeddings: float32 tensor, [B, N, C].
        text_lengths: int tensor, [B], valid tokens per sample. The mask is
            built with width ``max(text_lengths)``, so that maximum has to
            equal N for the shapes to line up.
        mel_embeddings: float32 tensor, [B, T, C].
        mel_lengths: int tensor, [B], valid frames per sample; its maximum
            has to equal T for the same reason.

    Returns:
        float32 tensor, [B, T, N + 1], of unnormalized scores. No softmax is
        applied here.
    """
    text_mask = mask(text_lengths)  # [B, N]
    mel_mask = mask(mel_lengths)  # [B, T]
    attn_mask = mel_mask[..., None] * text_mask[:, None, :]  # [B, T, N]

    keys = tf.transpose(text_embeddings, [0, 2, 1])  # [B, C, N]
    queries = tf.transpose(mel_embeddings, [0, 2, 1])  # [B, C, T]
    # Simplest attention score: negative squared distance, exponentiated.
    sq_diff = (queries[..., None] - keys[:, :, None, :]) ** 2  # [B, C, T, N]
    attn = tf.math.exp(-tf.reduce_sum(sq_diff, axis=1)) * attn_mask  # [B, T, N]
    # Append one all-zero class column on the right -> [B, T, N + 1].
    attn = tf.pad(attn, [[0, 0], [0, 0], [0, 1]], mode='constant')

    # One-hot over the N + 1 columns marking column text_lengths[b] per sample.
    maxlen = tf.reduce_max(text_lengths)
    columns = tf.range(0, maxlen + 1)[None]  # [1, N + 1]
    empty_mask = tf.cast(tf.equal(columns, text_lengths[:, None]), tf.float32)  # [B, N + 1]
    # attn is float32 (it was multiplied by the float32 mask), so ones_like
    # needs no further cast.
    empty_prob = 0.001 * tf.ones_like(attn) * empty_mask[:, None, :]  # [B, T, N + 1]
    return tf.add(attn, empty_prob)
    
def mask(lengths: tf.Tensor, maxlen = None):
    """Return a float32 validity mask for a batch of sequence lengths.

    Args:
        lengths: integer tensor with one length per sample.
        maxlen: width of the mask; when None the largest value in
            ``lengths`` is used.

    Returns:
        float32 tensor whose entry (b, s) is 1.0 when ``s < lengths[b]`` and
        0.0 otherwise.
    """
    width = tf.reduce_max(lengths) if maxlen is None else maxlen
    positions = tf.expand_dims(tf.range(width), axis=0)  # [1, S]
    per_sample = tf.expand_dims(lengths, axis=1)  # [B, 1]
    # Broadcasting the comparison yields the [B, S] mask.
    return tf.cast(tf.less(positions, per_sample), tf.float32)

def prepare_ctc_labels(lengths: tf.Tensor, maxlen = None):
    """Build batched CTC targets ``0, 1, ..., lengths[b] - 1`` for each sample.

    Validity is taken from the length mask rather than from ``labels != 0``:
    label 0 is a real class (the first position), so filtering on the value
    would silently drop it from every target sequence. The batch size is
    read from ``lengths`` instead of a module-level constant.

    Args:
        lengths: integer tensor, [B], number of labels per sample.
        maxlen: width of the dense label matrix; when None the largest value
            in ``lengths`` is used.

    Returns:
        A pair ``(labels_sparse, labels)``: the ``tf.SparseTensor`` holding
        only the valid positions, and the dense [B, maxlen] label matrix in
        which padded positions are 0.
    """
    if maxlen is None:
        maxlen = tf.reduce_max(lengths)
    positions = tf.range(maxlen)[None]  # [1, maxlen]
    valid = positions < lengths[:, None]  # [B, maxlen], bool
    batch_size = tf.shape(lengths)[0]
    labels = tf.tile(positions, [batch_size, 1]) * tf.cast(valid, tf.int32)
    # Coordinates of the valid (non-padding) positions, including label 0.
    idx = tf.where(valid)
    labels_sparse = tf.SparseTensor(idx, tf.gather_nd(labels, idx),
                                    tf.cast(tf.shape(labels), tf.int64))
    return labels_sparse, labels

def prepare_ctc_label(x):
    """Build the single-sample CTC target ``0, 1, ..., x - 1`` as a SparseTensor.

    Every position is a valid label here (there is no padding), so all of
    them are kept. Filtering on ``labels != 0`` would drop class 0, which is
    a real class: tf.nn.ctc_loss reserves the LAST class index for the blank,
    and body() crops the class axis to ``x + 1`` columns, which makes the
    blank index ``x`` and leaves 0..x-1 as ordinary labels.

    Args:
        x: scalar integer (Python int or tensor), the number of labels.

    Returns:
        ``tf.SparseTensor`` with dense shape [1, x] and values 0..x-1.
    """
    labels = tf.range(x)[None]  # [1, x]
    # All-true predicate: tf.where then lists every coordinate of `labels`.
    idx = tf.where(tf.greater_equal(labels, 0))
    return tf.SparseTensor(idx, tf.gather_nd(labels, idx),
                           tf.cast(tf.shape(labels), tf.int64))

def cond(i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label):
    """Loop predicate: keep iterating while the sample index ``i`` is below ``b``.

    Only ``i`` and ``b`` are inspected; the remaining parameters are accepted
    because the predicate receives the same loop variables as the loop body.
    """
    samples_remaining = i < b
    return samples_remaining

def body(i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label):
    """Process sample ``i``: append its CTC loss and its greedy decoding.

    Args:
        i: scalar index of the sample handled in this iteration.
        b: loop bound, passed through unchanged.
        T: scalar width that every decoded row is padded to.
        align_matrix: batched score matrix, indexed as [sample, time, class].
        x_len: per-sample label counts; sample ``i`` uses ``x_len[i] + 1``
            classes.
        y_len: per-sample frame counts; sample ``i`` uses ``y_len[i]`` frames.
        ctc_loss: 1-D accumulator of per-sample losses.
        decoded_label: 2-D int64 accumulator of decoded rows.

    Returns:
        The same eight loop variables with ``i`` incremented and the two
        accumulators each extended by one row.
    """
    n_frames = y_len[i]
    n_labels = x_len[i]
    seq_len = n_frames[None]  # length vector for a batch of one

    # Each sample has its own emission-matrix shape, so crop before use.
    emission = align_matrix[i:i + 1, :n_frames, :n_labels + 1]
    emission_tm = tf.transpose(emission, [1, 0, 2])  # time-major

    target = prepare_ctc_label(n_labels)
    sample_loss = tf.nn.ctc_loss(target, emission_tm, seq_len, time_major=True)
    ctc_loss = tf.concat([ctc_loss, sample_loss], axis=0)

    # Greedy decoding on the softmax-normalized scores; a beam-search decoder
    # (tf.nn.ctc_beam_search_decoder) could be substituted here.
    probs_tm = tf.transpose(tf.nn.softmax(emission, -1), [1, 0, 2])
    decoded, _ = tf.nn.ctc_greedy_decoder(probs_tm, seq_len,
                                          merge_repeated=False)
    best_path = decoded[0]
    dense_path = tf.sparse_to_dense(sparse_indices=best_path.indices,
                                    output_shape=best_path.dense_shape,
                                    sparse_values=best_path.values)

    # Right-pad the decoded row with -1 up to width T so rows can be stacked.
    pad_width = T - tf.shape(dense_path)[1]
    filler = -1 * tf.ones(shape=[1, pad_width], dtype=tf.int64)
    padded_path = tf.concat([dense_path, filler], axis=-1)
    decoded_label = tf.concat([decoded_label, padded_path], axis=0)

    return i + 1, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label

# Drive cond()/body() over every sample in the batch with tf.while_loop,
# accumulating one CTC loss value and one decoded row per sample.

# Loop counter, starting at sample 0.
i = tf.convert_to_tensor(0, dtype=tf.int32)
# Empty float32 vector; body() appends one loss per iteration.
ctc_loss = tf.constant(np.empty(shape=[0]), dtype=tf.float32)
align_matrix = get_aligment_matrix(x, x_len, y, y_len)
# b: size of axis 0 (loop bound); T: size of axis 1 (width decoded rows are padded to).
b, T = tf.shape(align_matrix)[0], tf.shape(align_matrix)[1] # batch_size
# Empty accumulator with zero rows and static width taken from align_matrix.
# It is int64 to match the int64 padding and decoder output that body()
# concatenates onto it.
decoded_label = tf.constant(np.empty(shape=[0,align_matrix.get_shape()[1]]), 
                            dtype=tf.int64)
# The two accumulators change shape every iteration, so their shape
# invariants are relaxed to unknown dimensions; all other loop variables keep
# their static shapes.
i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label \
            = tf.while_loop(cond, body,
                loop_vars=[i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label],\
                shape_invariants=[i.get_shape(),
                                  b.get_shape(),
                                  T.get_shape(),
                                  align_matrix.get_shape(),
                                  x_len.get_shape(),
                                  y_len.get_shape(),
                                  tf.TensorShape([None]),
                                  tf.TensorShape([None,None])]) # allow increasing shape
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # NOTE: this rebinds b and ctc_loss from graph tensors to the fetched values.
    b, ctc_loss, align, label = sess.run([b, ctc_loss, align_matrix, decoded_label])

# Per-sample losses, then their sum (both lines print the same 'ctc_loss: ' tag).
print('ctc_loss: ', ctc_loss)
print('ctc_loss: ', ctc_loss.sum())
print('decoded label: \n', label)
  1. 对比
num_classes, batch_size, seq_len = 3, 1, 30
import tensorflow as tf
# One sequence whose target is the single label 0.
labels = tf.SparseTensor(indices=[[0,0]], values=[0], dense_shape=[1,1])
# All-zero logits in time-major layout [seq_len, batch_size, num_classes].
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
# Use a context-managed session so it is closed after the fetch; a bare
# tf.InteractiveSession() stays open and remains installed as the default
# session.
with tf.Session() as sess:
    print(sess.run(loss))

import torch.nn as nn
import torch
# Repeated here so this snippet also runs on its own, without the TensorFlow
# snippet that first defines these constants.
num_classes, seq_len = 3, 30
# nn.CTCLoss uses class 0 as the blank by default, but the target below
# contains 0. Put the blank on the last class instead, so that the
# num_classes - 1 remaining classes (0 and 1) are ordinary labels.
ctcloss = nn.CTCLoss(blank=num_classes - 1, zero_infinity=True)
# nn.CTCLoss expects log-probabilities; raw zeros are not normalized, so
# apply log_softmax over the class axis.
inputs = torch.zeros((seq_len, 1, num_classes)).log_softmax(2)
target_seq = torch.IntTensor([[0, 1]])
loss = ctcloss(inputs, target_seq,
               input_lengths=(seq_len,),
               target_lengths=(num_classes - 1,))
print(loss)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值