batch_size=4, max_time=5, and label lengths [5, 4, 3, 2]; zeros in the dense labels mark padding.
import tensorflow as tf

labels = tf.Variable([[4, 3, 1, 2, 5],
                      [2, 3, 4, 1, 0],
                      [1, 2, 3, 0, 0],
                      [5, 4, 0, 0, 0]], dtype=tf.int32)
# keep only the non-zero (non-padding) entries
idx = tf.where(tf.not_equal(labels, 0))
sparse = tf.SparseTensor(idx, tf.gather_nd(labels, idx), labels.get_shape())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    s = sess.run(sparse)
    print(s.indices)
    print(s.values)
    print(s.dense_shape)

Output:
[[0 0]
[0 1]
[0 2]
[0 3]
[0 4]
[1 0]
[1 1]
[1 2]
[1 3]
[2 0]
[2 1]
[2 2]
[3 0]
[3 1]]
[4 3 1 2 5 2 3 4 1 1 2 3 5 4]
[4 5]
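On newer TF (1.15+ / 2.x) the same dense-to-sparse conversion is a one-liner; a minimal sketch, assuming tf.sparse.from_dense is available in your version:

import tensorflow as tf

labels = tf.constant([[4, 3, 1, 2, 5],
                      [2, 3, 4, 1, 0],
                      [1, 2, 3, 0, 0],
                      [5, 4, 0, 0, 0]])
# zeros are treated as the implicit default value and dropped
sparse = tf.sparse.from_dense(labels)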
- CTC example
https://blog.csdn.net/JackyTintin/article/details/79425866
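For reference, the standard CTC objective that tf.nn.ctc_loss implements: with per-frame class probabilities $y^t_k$ and $B$ the many-to-one map that removes blanks and collapses repeats,

$$ p(l \mid x) = \sum_{\pi \in B^{-1}(l)} \prod_{t=1}^{T} y^t_{\pi_t}, \qquad L_{\mathrm{CTC}} = -\log p(l \mid x) $$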
- search: per-sample CTC loss plus greedy / beam-search decoding over the alignment matrix
# tf.reset_default_graph()
import tensorflow as tf
import numpy as np

np.random.seed(1111)
B = 4
x = tf.constant(np.random.normal(loc=0.1, scale=0.3, size=[B, 5, 16]), dtype=tf.float32)   # text embeddings [B, N, C]
y = tf.constant(np.random.normal(loc=0.1, scale=0.3, size=[B, 10, 16]), dtype=tf.float32)  # mel embeddings [B, T, C]
x_len = tf.constant([3, 2, 5, 4])   # text lengths
y_len = tf.constant([6, 5, 10, 8])  # mel lengths
# x_len = tf.constant([5,])
# y_len = tf.constant([10,])

# based on RAD-TTS
def get_alignment_matrix(text_embeddings, text_lengths, mel_embeddings, mel_lengths):
    text_mask = mask(text_lengths)  # [B, N]
    mel_mask = mask(mel_lengths)    # [B, T]
    attn_mask = mel_mask[..., None] * text_mask[:, None, :]  # [B, T, N]
    keys = tf.transpose(text_embeddings, [0, 2, 1])    # [B, C, N]
    queries = tf.transpose(mel_embeddings, [0, 2, 1])  # [B, C, T]
    # simplest attention score: exp of the negative squared L2 distance
    attn = (queries[..., None] - keys[:, :, None, :]) ** 2  # [B, C, T, N]
    attn = tf.math.exp(-tf.reduce_sum(attn, axis=1, keepdims=False)) * attn_mask  # [B, T, N]
    # append one extra class column to hold the CTC blank symbol
    attn = tf.pad(attn, [[0, 0], [0, 0], [0, 1]], mode='constant')  # [B, T, N+1]
    maxlen = tf.reduce_max(text_lengths)
    # give the blank column (position text_lengths[i] of sample i) a small non-zero score
    empty_mask = tf.cast(tf.equal(tf.range(0, maxlen + 1)[None], text_lengths[:, None]), tf.int32)
    empty_prob = 0.001 * tf.ones_like(attn) * tf.cast(empty_mask[:, None, :], tf.float32)
    align_matrix = tf.add(attn, empty_prob)
    # ctc_loss applies softmax internally, so return the unnormalized scores;
    # tf.nn.softmax(align_matrix, -1) would give the [B, T, N+1] attention alignment
    return align_matrix
def mask(lengths: tf.Tensor, maxlen=None):
    if maxlen is None:
        maxlen = tf.reduce_max(lengths)
    # [B, S]
    return tf.cast(tf.range(maxlen)[None] < lengths[:, None], tf.float32)
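# note: tf.sequence_mask(lengths, maxlen) builds the same mask as booleans;
# cast to float32 to match the version above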
def prepare_ctc_labels(lengths: tf.Tensor, maxlen=None):
    # batched variant (unused below); caveat: filtering on not_equal(labels, 0)
    # drops token index 0 as well as the padding
    if maxlen is None:
        maxlen = tf.reduce_max(lengths)
    mask = tf.cast(tf.range(maxlen)[None] < lengths[:, None], tf.int32)
    labels = tf.tile(tf.range(maxlen)[None], [B, 1]) * mask
    # labels = tf.tile(tf.range(maxlen)[None], [B, 1]) * mask + (1 - mask) * maxlen  # pad with maxlen as the ctc blank symbol
    idx = tf.where(tf.not_equal(labels, 0))
    labels_sparse = tf.SparseTensor(idx, tf.gather_nd(labels, idx),
                                    tf.cast(tf.shape(labels), tf.int64))
    return labels_sparse, labels
def prepare_ctc_label(x):
    # label sequence for one sample: token indices 0 .. x-1
    labels = tf.range(x)[None]
    # keep every position: 0 is a real token index here (not padding), so
    # filtering with not_equal(labels, 0) would wrongly drop the first token
    idx = tf.where(labels >= 0)
    labels = tf.SparseTensor(idx, tf.gather_nd(labels, idx),
                             tf.cast(tf.shape(labels), tf.int64))
    return labels
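# e.g. prepare_ctc_label(3) yields the sparse label sequence [[0, 1, 2]]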
def cond(i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label):
    return i < b

def body(i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label):
    # every sample has its own emission-matrix shape: [1, y_len[i], x_len[i]+1]
    attn = align_matrix[i:i+1, :y_len[i], :x_len[i]+1]
    label = prepare_ctc_label(x_len[i])
    loss = tf.nn.ctc_loss(label,
                          tf.transpose(attn, [1, 0, 2]),  # time-major: [T, 1, classes]
                          y_len[i][None], time_major=True)
    ctc_loss = tf.concat([ctc_loss, loss], axis=0)
    p_label, _ = tf.nn.ctc_greedy_decoder(
        tf.transpose(tf.nn.softmax(attn, -1), [1, 0, 2]),
        y_len[i][None],
        merge_repeated=False)
    # p_label, _ = tf.nn.ctc_beam_search_decoder(
    #     tf.transpose(tf.nn.softmax(attn, -1), [1, 0, 2]),
    #     y_len[i][None],
    #     beam_width=100,
    #     top_paths=1,
    #     merge_repeated=False)
    p_label = tf.sparse_to_dense(sparse_indices=p_label[0].indices,
                                 output_shape=p_label[0].dense_shape,
                                 sparse_values=p_label[0].values)
    # pad the decoded path with -1 up to the max mel length T
    padding = -1 * tf.ones(shape=[1, T - tf.shape(p_label)[1]], dtype=tf.int64)
    p_label = tf.concat([p_label, padding], axis=-1)
    decoded_label = tf.concat([decoded_label, p_label], axis=0)
    i = i + 1
    return i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label
i = tf.convert_to_tensor(0, dtype=tf.int32)
ctc_loss = tf.constant(np.empty(shape=[0]), dtype=tf.float32)
align_matrix = get_alignment_matrix(x, x_len, y, y_len)
b, T = tf.shape(align_matrix)[0], tf.shape(align_matrix)[1]  # batch size, max mel length
# ctc_greedy_decoder returns int64 labels, so the accumulator must be int64 as well
decoded_label = tf.constant(np.empty(shape=[0, align_matrix.get_shape()[1]]),
                            dtype=tf.int64)
i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label \
    = tf.while_loop(cond, body,
                    loop_vars=[i, b, T, align_matrix, x_len, y_len, ctc_loss, decoded_label],
                    shape_invariants=[i.get_shape(),
                                      b.get_shape(),
                                      T.get_shape(),
                                      align_matrix.get_shape(),
                                      x_len.get_shape(),
                                      y_len.get_shape(),
                                      tf.TensorShape([None]),
                                      tf.TensorShape([None, None])])  # allow growing shapes
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    b, ctc_loss, align, label = sess.run([b, ctc_loss, align_matrix, decoded_label])
    print('ctc_loss per sample: ', ctc_loss)
    print('total ctc_loss: ', ctc_loss.sum())
    print('decoded label: \n', label)
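Because the greedy decoder strips blanks and merge_repeated=False keeps repeated tokens, each decoded row assigns one text-token index per emitting mel frame, so counting occurrences approximates per-token durations (the point of RAD-TTS-style alignment). A minimal numpy sketch; the helper path_to_durations is illustrative, not part of the code above:

import numpy as np

def path_to_durations(path, n_tokens):
    # count how many decoded frames were assigned to each text token;
    # the -1 padding falls outside the valid range and is skipped
    durations = np.zeros(n_tokens, dtype=np.int32)
    for t in path:
        if 0 <= t < n_tokens:
            durations[t] += 1
    return durations

# e.g. for sample 0 above (x_len[0] == 3): path_to_durations(label[0], 3)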
- comparison: tf.nn.ctc_loss (TF1) vs torch.nn.CTCLoss
import tensorflow as tf

num_classes, batch_size, seq_len = 3, 1, 30
# TF1 ctc_loss takes unnormalized logits (softmax is applied internally)
# and reserves index num_classes-1 for the blank symbol
labels = tf.SparseTensor(indices=[[0, 0]], values=[0], dense_shape=[1, 1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

import torch
import torch.nn as nn

# torch CTCLoss expects log-probabilities and defaults to blank=0,
# so as written the two losses are not directly comparable
ctcloss = nn.CTCLoss(zero_infinity=True)
inputs = torch.zeros((seq_len, 1, num_classes))
target_seq = torch.IntTensor([[0, 1]])
loss = ctcloss(inputs, target_seq, input_lengths=[seq_len], target_lengths=[num_classes - 1])
print(loss)
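To make the two frameworks agree, a minimal sketch under the conventions noted above: normalize the torch inputs with log_softmax, move the blank to the last index, and use the same single-label target as the TF call:

import torch
import torch.nn as nn

num_classes, seq_len = 3, 30
# log_softmax of zeros = uniform log-probs, matching what TF1's ctc_loss
# computes internally from the zero logits above
inputs = torch.zeros((seq_len, 1, num_classes)).log_softmax(2)
# align the blank with TF1's convention (blank = num_classes - 1)
ctcloss = nn.CTCLoss(blank=num_classes - 1, zero_infinity=True)
target_seq = torch.IntTensor([[0]])  # same single label 0 as the TF example
loss = ctcloss(inputs, target_seq, input_lengths=[seq_len], target_lengths=[1])
# should match the TF value: reduction='mean' divides by the target length,
# which is 1 here
print(loss)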