主要借助tf.diag_part
和tf.matrix_diag
两个方法来将方阵对角线置0.
- tf.diag_part
函数返回tensor的对角线元素in : inputs = [[1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]] in : sess.run(tf.diag_part(inputs)) out: array([1, 3, 5, 7], dtype=int32)
- tf.matrix_diag
构造对角线矩阵# 对角线元素 in : x = tf.diag_part(inputs) in: matrix = tf.matrix_diag(x) # 原矩阵减去对角矩阵,即可实现对角线元素置0 in: sess.run(inputs- matrix) out: array([[0, 2, 3, 4], [2, 0, 4, 5], [3, 4, 0, 6], [4, 5, 6, 0]], dtype=int32)