一些mask的操作理解

关于mask的使用,常见的用法在进行padding的时候,例如:

1. 对矩阵获取句子长度

1
2
3
4
5
6
from torch.nn.utils.rnn import pad_sequence

a = [torch.tensor([1,2, 3]), torch.tensor([4,5])]
b = pad_sequence(a, batch_first=True)
mask = b.not_equal(0)
b[mask].split(mask.sum(1).tolist())

2. 计算loss的时候把mask加上

略.

3. 比如三维矩阵操作mask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
a = torch.arange(24).reshape(2, 3, 4)

Out[28]:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])



mask = torch.tril(torch.ones(2, 3)).bool()

Out[29]:
tensor([[ True, False, False],
[ True, True, False]])


# 假设mask为下面的如何理解?
a[mask]

Out[30]:
tensor([[ 0, 1, 2, 3],
[12, 13, 14, 15],
[16, 17, 18, 19]])

简单理解,就是将sequence_length中padding位忽略掉。

4. gather

1
2
3
4
5
6
7
8
9
10
# 例如这句,他在ltp中:https://github.com/HIT-SCIR/ltp/blob/f3d4a25ee2fbb71613f76c99a47e70a5445b8c03/ltp/transformer_rel_linear.py#L58 中出现 
input = torch.gather(input, dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, input.size(-1)))

# word_index的获取方式类似:
input_ids = [112, 4423, 232]
word_index = torch.arange(input_ids.size(0))

# 猜测的意思: 因为有填充位,所以上面把sequence_length中填充位的都不取。

# 但是实际测试下来并不是,看下面例子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf8 -*-
#
import torch

model = torch.nn.Embedding(9, 10)

input_embed = torch.tensor([[2, 3, 4, 5, 0], [6, 7, 8, 0, 0]])
word_index = torch.tensor([[0, 1, 2, 3, 0], [0, 1, 2, 0, 0, ]])
out = model(input_embed)
print(out, out.shape)
result = torch.gather(out, dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, out.size(-1)))

print(result, result.shape)


"""
tensor([[[-0.1290, -0.8854, -0.0281, -1.0047, -0.6565, 0.2380, 1.6062,
-1.0218, 0.9187, -0.5830],
[ 1.7823, 0.1437, 0.8367, 0.3261, 0.0991, -0.8338, 1.5731,
2.6733, 0.2048, -0.4198],
[ 0.9965, 2.6325, 1.1463, -0.3047, 0.7547, -1.9135, -1.9450,
0.1363, 1.5608, 1.0028],
[-0.3929, 0.3888, 0.3454, -0.5054, -0.0680, -0.3803, 1.2884,
-1.1461, -0.3259, 0.6795],
[ 0.1853, -0.8628, -1.4179, 1.4490, 0.6766, 0.7106, -0.6956,
-0.2125, 0.2669, -0.0373]],

[[-0.9646, 0.6840, 1.7190, 1.2912, -1.1019, 0.6682, -0.0500,
-0.2393, 0.8611, 1.2914],
[ 0.6630, 0.7863, -0.2253, -1.5720, -0.4309, -2.0466, -1.0762,
0.5243, -0.3297, -0.0142],
[ 0.0903, -1.0030, 0.1973, 0.9981, 1.2901, -0.5555, -0.2912,
-0.6930, -0.1299, -0.9054],
[ 0.1853, -0.8628, -1.4179, 1.4490, 0.6766, 0.7106, -0.6956,
-0.2125, 0.2669, -0.0373],
[ 0.1853, -0.8628, -1.4179, 1.4490, 0.6766, 0.7106, -0.6956,
-0.2125, 0.2669, -0.0373]]], grad_fn=<EmbeddingBackward>) torch.Size([2, 5, 10])
tensor([[[-0.1290, -0.8854, -0.0281, -1.0047, -0.6565, 0.2380, 1.6062,
-1.0218, 0.9187, -0.5830],
[ 1.7823, 0.1437, 0.8367, 0.3261, 0.0991, -0.8338, 1.5731,
2.6733, 0.2048, -0.4198],
[ 0.9965, 2.6325, 1.1463, -0.3047, 0.7547, -1.9135, -1.9450,
0.1363, 1.5608, 1.0028],
[-0.3929, 0.3888, 0.3454, -0.5054, -0.0680, -0.3803, 1.2884,
-1.1461, -0.3259, 0.6795],
[-0.1290, -0.8854, -0.0281, -1.0047, -0.6565, 0.2380, 1.6062,
-1.0218, 0.9187, -0.5830]],

[[-0.9646, 0.6840, 1.7190, 1.2912, -1.1019, 0.6682, -0.0500,
-0.2393, 0.8611, 1.2914],
[ 0.6630, 0.7863, -0.2253, -1.5720, -0.4309, -2.0466, -1.0762,
0.5243, -0.3297, -0.0142],
[ 0.0903, -1.0030, 0.1973, 0.9981, 1.2901, -0.5555, -0.2912,
-0.6930, -0.1299, -0.9054],
[-0.9646, 0.6840, 1.7190, 1.2912, -1.1019, 0.6682, -0.0500,
-0.2393, 0.8611, 1.2914],
[-0.9646, 0.6840, 1.7190, 1.2912, -1.1019, 0.6682, -0.0500,
-0.2393, 0.8611, 1.2914]]], grad_fn=<GatherBackward>) torch.Size([2, 5, 10])

Process finished with exit code 0

"""
1. 比如[2, 3, 4, 5, 0],有一个填充位,那么最后一个0取出来的结果和2取出来的结果一致。
2. 比如[6, 7, 8, 0, 0],有两个填充位,那么最后两个0取出来的结果和6取出来的结果一致。

从而证明ltp中那一行代码错误.

5. masked_fill

好久就忘了,遇到再补充。

这个可以用于~mask填充math.inf,emmm,只想到这么多了。

6. 二维mask -> 三维mask

这个在变成4维(最后一维表示feature)时用到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
mask = torch.tensor([
[1, 1, 1, 0],
[1, 1, 0, 0]
])
mask3d = mask.unsqueeze(-1) & mask.unsqueeze(-2)
print(mask3d)

Out:
tensor([[[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]],

[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]])

7. flatten操作

1
2
3
4
5
6
7
8
9
10
11
12
13
a = torch.arange(12).view(2, 2, 3)
a
Out[39]:
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
a.flatten(end_dim=1)
Out[40]:
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值