现有如下输入:
1. [Batch_size, Seq_len, N_classes]的tensor T
2. [Batch_size, Seq_len]的mask矩阵 M
需求:根据M中的值来mask T,并去掉被mask掉的值
实际场景:主要是现在有一个对话的每个时刻的状态,和角色(0/1),需要取出角色为1的所有utterances的状态,在此基础上选取最后一个角色为1的utterance的状态
问题点:
若使用tf.boolean_mask,则会得到[?, N_classes], ?代表这个方法将每个样本筛选出来的样本压缩成一维,是一个不规则的维度,?数值小于Batch_size * Seq_len
正确操作:
使用tf.ragged.boolean_mask,保留原始的维度,利用不规则tensor特性,得到S=[Batch_size, ?, N_classes], 再使用tf.squeeze(S[:, :-1,:], 1),得到期望结果
注:Ragged Tensor不支持直接index,所以这里使用slice操作,再压缩中间维度。