Building an attention mask or causal mask in TensorFlow

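import tensorflow as tf

def upper_triangle_bias(D, dtype=tf.float32):
    """Create an upper-triangle bias matrix for causal decoding."""
    # band_part(x, -1, 0) keeps the lower triangle including the diagonal,
    # so 1 - lower marks the strictly upper (future) positions with 1.
    upper_triangle_DxD = 1 - tf.linalg.band_part(
        tf.ones([D, D], dtype=dtype), -1, 0)
    # Future positions get dtype.min, allowed positions get 0; add a batch axis.
    tensor_1xDxD = tf.expand_dims(upper_triangle_DxD * dtype.min, axis=0)
    return tensor_1xDxD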
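To show what the bias does once it reaches attention, here is a small sketch that mirrors the same construction in NumPy so it runs without TensorFlow. `np.tril` plays the role of `tf.linalg.band_part(x, -1, 0)`, and `upper_triangle_bias_np` and `masked_softmax` are hypothetical helper names used only for illustration. The bias is added to the attention logits (in TensorFlow, a `[batch, D, D]` logits tensor broadcasts against the leading axis of size 1) before the softmax, so every future position ends up with zero weight.

```python
import numpy as np

def upper_triangle_bias_np(D, dtype=np.float32):
    """Hypothetical NumPy mirror of upper_triangle_bias: shape [1, D, D]."""
    # np.tril keeps the lower triangle incl. the diagonal, like band_part(x, -1, 0).
    upper = 1 - np.tril(np.ones((D, D), dtype=dtype))
    # Strictly upper (future) entries become the most negative finite value.
    return (upper * np.finfo(dtype).min)[np.newaxis, ...]

def masked_softmax(logits, bias):
    """Hypothetical helper: add the bias to logits, then a stable softmax."""
    scores = logits + bias
    # Subtract the row max for numerical stability; masked entries stay hugely negative.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)  # exp of a huge negative number underflows to 0
    return weights / weights.sum(axis=-1, keepdims=True)

bias = upper_triangle_bias_np(4)
logits = np.zeros((1, 4, 4), dtype=np.float32)  # uniform scores for illustration
weights = masked_softmax(logits, bias)
print(np.round(weights[0], 3))
```

With uniform logits, row `i` spreads its weight evenly over positions `0..i` and gives nothing to later positions. Using the finite `dtype.min` rather than `-inf` is a common choice: if a row were ever fully masked, `-inf - (-inf)` would produce NaN in the softmax, while a finite value does not.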