最近在看论文代码的时候偶然看到了一个之前没有见到过的函数,现介绍如下(如有不对欢迎指正):
- tf.linalg 类:
用于线性计算的类,下面包括了好多线性操作的类,tf.linalg.LinearOperator()就是其中之一,还包括了如转置,奇异值分解,计算张量的迹等等关于张量线性运算的一些函数。 - tf.linalg.LinearOperatorLowerTriangular()
顾名思义这个类是针对下三角张量的一些列操作。先看一个例子
a = tf.constant([[1.0, 3.0, 3.0], [2.0, 3.0, 4.0], [1.0, 3.0, 4.0]])
tril = tf.linalg.LinearOperatorLowerTriangular(a)
with tf.Session() as sess:
print(sess.run(tril.to_dense()))
输出结果为:
[[1. 0. 0.]
[2. 3. 0.]
[1. 3. 4.]]
可见这个函数把输入的张亮转换为了下三角矩阵,但是如果是不规则的张亮会出现什么的情况呢?
首先是shape = [4,3] 的张量:
a = tf.constant([[1.0, 3.0, 3.0], [2.0, 3.0, 4.0], [1.0, 3.0, 4.0], [3, 4, 5]])
tril = tf.linalg.LinearOperatorLowerTriangular(a)
with tf.Session() as sess:
print(sess.run(tril.to_dense()))
输出结果是:
[[1. 0. 0.]
[2. 3. 0.]
[1. 3. 4.]
[3. 4. 5.]]