【阅读源码】Transformer的mask机制-sequence_mask代码解读

import torch
import numpy as np
import matplotlib.pyplot as plt
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    print(attn_shape)
    print(np.ones(attn_shape))
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    print(subsequent_mask)
    return torch.from_numpy(subsequent_mask) == 0
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])

涉及的知识点:

1 np.triu or numpy.triu

  • 对于m*n m<n的矩阵
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=1)))
#输出
数组的上三角部分:
[[ 1  2  3]
 [ 4  5  6]
 [ 0  8  9]
 [ 0  0 12]
 [ 0  0  0]
 [ 0  0  0]
 [ 0  0  0]]
数组的上三角部分:
[[1 2 3]
 [0 5 6]
 [0 0 9]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]]
数组的上三角部分:
[[0 2 3]
 [0 0 6]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]]

矩阵的shape是(7,3),可见k=-1是从第三行(index=2)为下标开始的,依次类推k=0是从第二行(index=1)为下标开始的,k=1是从第一行(index=0)为下标开始的

  • 对于n*n的矩阵
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=1)))
#输出
数组的上三角部分:
[[1 2 3]
 [4 5 6]
 [0 8 9]]
数组的上三角部分:
[[1 2 3]
 [0 5 6]
 [0 0 9]]
数组的上三角部分:
[[0 2 3]
 [0 0 6]
 [0 0 0]]
  • 对于m*n m<n的矩阵
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=1)))
#输出
数组的上三角部分:
[[1 2 3]
 [4 5 6]]
数组的上三角部分:
[[1 2 3]
 [0 5 6]]
数组的上三角部分:
[[0 2 3]
 [0 0 6]]

从第一行可以看到对于这个2*3的矩阵,k=-1表示从第三行开始,但是矩阵没有第三行,所以原样输出
其他k的取值还是按照之前陈述的规律输出

2 astype

subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')

作用就是numpy.ndarray类型中数字转换成uint8类型的数据
uint8表示:uint8是8位无符号整型

3 torch.from_numpy(subsequent_mask)

是将ndarray类型的数据转换成tensor类型的数据

4 torch.from_numpy(subsequent_mask) == 0

将每个位置的数==0和零判断是否相等,如果=0,此位置为True,否为False
目的:是将下三角为0,上三角为1的矩阵进行翻转得到,下三角为True,上三角为False

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值