学习transformer的掩码机制时遇到了这个函数,因此记录一下。
函数定义:
tf.linalg.band_part(
input,
num_lower,
num_upper,
name=None
)
作用:
以对角线为中心,取它的副对角线部分,其他部分用0填充。
参数:
先解释一下副对角线,即矩阵中除了主对角线以外的其它对角线。
- input:输入的张量。
- num_lower:下三角矩阵保留的副对角线数量,取值为负数时全部保留,为0时全为0。
- num_upper:上三角矩阵保留的副对角线数量,取值为负数时全部保留,为0时全为0。
示例:
import tensorflow as tf
tf.enable_eager_execution()
a=tf.constant( [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11,12],
[13, 14, 15,16]],dtype=tf.float32)
b=tf.linalg.band_part(a,2,0