tf.sequence_mask()函数
sequence_mask(
lengths, # 原始序列长度
maxlen=None, # 最大序列长度
dtype=tf.bool, # 类型是bool(True,Fales),也可以是float(1 or 0)
name=None
)
return mask类型数据,默认类型为bool,返回mask张量
例子
dtype=tf.bool时,返回的是True,Fales的张量
dtype=tf.float时,返回的是1,0的张量
只给定lengths时
lenght = 3
mask_data = tf.sequence_mask(lengths=lenght)
# 输出结果是长度为3的array,前四个True
array([ True, True, True])
给定lengths和maxlen时(dtype不给定,默认为bool)
lenght = 3
maxlen = 5
mask_data = tf.sequence_mask(lengths=lenght,maxlen=maxlen )
# 输出结果是长度为5的array,前三个True
array([ True, True, True, False, False])
定义lengths,maxlen,dtype时
lenght = 3
maxlen = 5
dtype = tf.float32
mask_data = tf.sequence_mask(lengths=lenght,maxlen=maxlen ,dtype=dtype )
# 输出结果是长度为5的array,前三个1.0后俩个为0.0
array([1., 1., 1., 1., 0., 0.], dtype=float32)