函数原型
tf.linalg.band_part(
input, num_lower, num_upper, name=None
)
函数说明
band_part函数主要用于处理方形矩阵的副对角线上的元素。以对角线为中心,对副对角线上的元素进行取舍(是否用0填充)。
参数num_lower表示下三角矩阵保留的副对角线的数量,比如num_lower=2表示下三角矩阵从第二条副对角线开始,之后的所有的副对角线的元素全部用0填充。类似的,参数num_upper表示上三角矩阵保留的副对角线的数量。注意,如果为负数,则表示全部保留。
函数使用
>>> a = [[1, 2, 3, 4],
[2, 1, 5, 6],
[3, 5, 1, 7],
[4, 6, 7, 1]]
>>> b = tf.constant(a)
>>> b
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 4],
[2, 1, 5, 6],
[3, 5, 1, 7],
[4, 6, 7, 1]])>
>>> c = tf.linalg.band_part(b, 1, -1)
>>> c
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 4],
[2, 1, 5, 6],
[0, 5, 1, 7],
[0, 0, 7, 1]])>
>>> d = tf.linalg.band_part(b, 2, 2)
>>> d
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 0],
[2, 1, 5, 6],
[3, 5, 1, 7],
[0, 6, 7, 1]])>