tensorflow去上下三角矩阵:tf.linalg.band_part

本文介绍了TensorFlow中的tf.linalg.band_part函数,用于从矩阵中提取或填充上下三角部分,详细解释了函数作用、参数含义,并通过示例进行说明,适用于理解和实现Transformer的掩码机制。
摘要由CSDN通过智能技术生成

  学习transformer的掩码机制时遇到了这个函数,因此记录一下。
函数定义:

tf.linalg.band_part(
    input,
    num_lower,
    num_upper,
    name=None
)

作用:
  以对角线为中心,取它的副对角线部分,其他部分用0填充。

参数:
  先解释一下副对角线,即矩阵中除了主对角线以外的其它对角线。

  • input:输入的张量。
  • num_lower:下三角矩阵保留的副对角线数量,取值为负数时全部保留,为0时全为0。
  • num_upper:上三角矩阵保留的副对角线数量,取值为负数时全部保留,为0时全为0。

示例:

import tensorflow as tf
tf.enable_eager_execution()
a=tf.constant( [[1,  2,  3, 4],
				[5,  6,  7, 8],
				[9, 10,  11,12],
                [13, 14, 15,16]],dtype=tf.float32)
b=tf.linalg.band_part(a,2,0
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值