tf.linalg.band_part(input,num_lower,num_upper),
此函数的含义:以主对角线为中心,取它的副对角线部分,其他部分用0 填充
input:输入的张量
num_lower:从主对角线开始计算,下三角矩阵保留的副对角线数量,取值为负数时,则全部保留
num_upper:从主对角线开始计算,上三角矩阵保留的副对角线数量,取值为负数时,则全部保留
import tensorflow as tf
a=tf.constant( [[ 1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1],
[1, 1, 1, 1]],dtype=tf.float32)
#没有保留的行用0填充
b=tf.linalg.band_part(a,3,1)#下三角矩阵保留3行,上三角保留1行,
c=tf.linalg.band_part(a,2,1)#下三角矩阵保留2行,上三角保留1行
d=tf.linalg.band_part(a,1,1)#下三角保留1行,上三角保留1行
e=tf.linalg.band_part(a,-1,1)#下三角全部保留,上三角保留一行
print(a)
print(b)
print(c)
print(d)
print(e)
结果
tf.Tensor(
[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]], shape=(4, 4), dtype=float32)
tf.Tensor(
[[1. 1. 0. 0.]
[1. 1. 1. 0.]
[1. 1. 1. 1.]<