While reading the GPT-2 source code, I came across this snippet:
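import tensorflow as tf

def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.
    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:, None]  # column vector of row indices, shape [nd, 1]
    j = tf.range(ns)           # row vector of column indices, shape [ns]
    m = i >= j - ns + nd       # broadcasts to a boolean [nd, ns] matrix
    return tf.cast(m, dtype)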
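To see what this mask looks like, here is a minimal pure-Python sketch. It is not the GPT-2 code itself: `attention_mask_py` and `band_part_ones` are hypothetical helper names introduced only for illustration, and the second one simply applies the `in_band` rule from the `tf.matrix_band_part` docstring to an all-ones matrix, so the equivalence claimed in the docstring above can be checked by hand.

```python
def attention_mask_py(nd, ns):
    """Hypothetical pure-Python version of the mask: entry [i][j] is 1 when i >= j - ns + nd."""
    return [[1 if i >= j - ns + nd else 0 for j in range(ns)] for i in range(nd)]


def band_part_ones(nd, ns, num_lower, num_upper):
    """Apply the in_band(m, n) rule from the docstring to an nd x ns matrix of ones."""
    def in_band(m, n):
        # A negative bound means "keep the whole triangle on that side".
        keep_lower = num_lower < 0 or (m - n) <= num_lower
        keep_upper = num_upper < 0 or (n - m) <= num_upper
        return keep_lower and keep_upper

    return [[1 if in_band(m, n) else 0 for n in range(ns)] for m in range(nd)]


nd, ns = 3, 5
mask = attention_mask_py(nd, ns)
for row in mask:
    print(row)
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]

# Same result as the band-part formulation with num_lower=-1, num_upper=ns-nd.
print(mask == band_part_ones(nd, ns, -1, ns - nd))  # True
```

Each row i may attend to columns up to j = i + (ns - nd), so the triangle of ones is anchored at the lower-right corner rather than the upper-left one. When nd equals ns this reduces to the ordinary lower-triangular causal mask.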
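I had not seen the function tf.matrix_band_part before. Nearly everything I found when searching was just the docstring from the source code, and I could not make sense of the docstring either: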
def matrix_band_part(input, num_lower, num_upper, name=None):
    r"""Copy a tensor setting everything outside a central band in each innermost matrix
    to zero.
    The `band` part is computed as follows:
    Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
    tensor with the same shape where
    `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
    The indicator function
    `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
                     (num_upper < 0 || (n-m) <= num_upper)`.
    For example:
```
# if 'input' is [[ 0,  1,  2, 3]
                 [-1,  0,  1, 2]
                 [-2, -1,  0, 1]
                 [-3, -2, -1, 0]],
tf.matrix_band_part(input, 1, -1) ==> [[