biaffine分析

下面改动这个代码,方便理解和实验

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

class Biaffine(torch.nn.Module):
def __init__(self, n_in=768, n_out=2, bias_x=True, bias_y=True):
super(Biaffine, self).__init__()
self.n_in = n_in
self.n_out = n_out
self.bias_x = bias_x
self.bias_y = bias_y

self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y))
self.reset_parameters()
def reset_parameters(self):
# 改动这里
nn.init.ones_(self.weight)

def forward(self, x, y):
if self.bias_x:
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
if self.bias_y:
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)

b = x.shape[0]
o = self.weight.shape[0]

x = x.unsqueeze(1).expand(-1, o, -1, -1)
weight = self.weight.unsqueeze(0).expand(b, -1, -1, -1)
y = y.unsqueeze(1).expand(-1, o, -1, -1)

s = torch.matmul(torch.matmul(x, weight), y.permute((0, 1, 3, 2)))
if s.shape[1] == 1:
s = s.squeeze(dim=1)
return s


# 1 表示分类个数,下面改成2也是在此处哦
model = Biaffine(3, 1, bias_x=False, bias_y=False)
x = torch.arange(12, dtype=torch.float).reshape(2, 2, 3)
y = torch.arange(12, dtype=torch.float).reshape(2, 2, 3)
result = model(x, y)
print(result)

分析:

1. x.shape

1
2
3
4
5
6
7
8
x
Out[56]:
tensor([[[ 0., 1., 2.],
[ 3., 4., 5.]],
[[ 6., 7., 8.],
[ 9., 10., 11.]]])

表示的是sequence_length为2, 那么[0,1,2]表示第一个字,[3,4,5]表示第二个字.

2. x和weight点积

1
2
3
4
5
6
7
8
## 中间过程忽略

torch.matmul(x, weight)
Out[2]:
tensor([[[[ 3., 3., 3.],
[12., 12., 12.]]],
[[[21., 21., 21.],
[30., 30., 30.]]]], grad_fn=<UnsafeViewBackward>)

3. 怎么来的呢

1
2
3
4
5
6

0 * 1 + 1 * 1 + 2 * 1 = 3
3 * 1 + 4 * 1 + 5 * 1 = 12

6 * 1 + 7 * 1 + 8 * 1 = 21
9 * 1 + 10 * 1 + 11 * 1 = 30

4. y.permute((0, 1, 3, 2))

1
2
3
4
5
6
7
8
9
y.permute((0, 1, 3, 2))

Out[3]:
tensor([[[[ 0., 3.],
[ 1., 4.],
[ 2., 5.]]],
[[[ 6., 9.],
[ 7., 10.],
[ 8., 11.]]]])

5. torch.matmul(torch.matmul(x, weight), y.permute((0, 1, 3, 2)))

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
torch.matmul(torch.matmul(x, weight), y.permute((0, 1, 3, 2)))
Out[4]:
tensor([[[[ 9., 36.],
[ 36., 144.]]],
[[[441., 630.],
[630., 900.]]]], grad_fn=<UnsafeViewBackward>)


# 怎么来的呢

3 * 0 + 3 * 1 + 3 * 2 = 9
3 * (3 + 4 + 5) = 36
12 * (0 + 1 + 2) = 36
12 * (3 + 4 + 5) = 144



# 发现了没,sequence_length那一维(即每个字)都和其他的字进行点积,并乘以权重,从而获取最终的output。
# 是不是瞬间发现厉害的地方。。。。。

final. 扩展

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 扩展,如果out为2的话

s
Out[2]:
tensor([[[[ 9., 36.], // 分类1
[ 36., 144.]],

[[ 9., 36.], // 分类2
[ 36., 144.]]],

[[[441., 630.],
[630., 900.]],
[[441., 630.],
[630., 900.]]]], grad_fn=<UnsafeViewBackward>)

// 即表示多分类情况下每个分类的输出,通过反向更新weight
// 是不是get到重点了。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值