新版本,tf.matrix_band_part挪到了tf.linalg.band_part,它的主要功能是以对角线为中心,取它的副对角线部分,其他部分设置为0,视觉就像是一个band(带),tf.linalg.band_part主要有三个参数
input:要输入的张量tensor.
num_lower:下三角矩阵保留的副对角线数量,从主对角线开始计算,相当于下三角的带宽。取值为负数时,则全部保留,矩阵不变。
num_upper:上三角矩阵保留的副对角线数量,从主对角线开始计算,相当于上三角的带宽。取值为负数时,则全部保留,矩阵不变。
import tensorflow as tf
tf.enable_eager_execution()
a=tf.constant( [[ 1, 1, 2, 3],[-1, 2, 1, 2],[-2, -1, 3, 1],
[-3, -2, -1, 5]],dtype=tf.float32)
b=tf.linalg.band_part(a,3,0)
c=tf.linalg.band_part(a,3,1)
print(b)
tf.Tensor(
[[ 1. 0. 0. 0.]
[-1. 2. 0. 0.]
[-2. -1. 3. 0.]
[-3. -2. -1. 5.]], shape=(4, 4), dtype=float32)
print(c)
tf.Tensor(
[[ 1. 1. 0. 0.]
[-1. 2. 1. 0.]
[-2. -1. 3. 1.]
[-3. -2. -1. 5.]], shape=(4, 4), dtype=float32)