padding_mask实战
-
在使用 Transformer 模型时,padding_mask 是非常重要的一部分。它用于在处理变长序列时,标识出填充位置(padding)的掩码,以避免模型将填充的零值考虑在内。
-
以下是关于 padding_mask 的一些详细信息和如何在代码中使用它的示例:
什么是 Padding Mask?
Padding Mask 是一个布尔掩码矩阵,通常用于注意力机制中,来指示哪些位置是填充的。模型在计算注意力时会忽略这些填充位置,以避免对无意义的填充值进行处理,从而提高模型的有效性。
如何使用 Padding Mask?
在 Transformer 中,padding_mask 可以在多头自注意力(Multi-Head Self-Attention)层中使用。具体来说,它可以通过将填充位置的注意力权重设为负无穷来实现,从而使得这些位置在 softmax 操作之后的权重为零。
import torch
import torch.nn as nn
import torch.nn.functional as F
# Toy batch of token-id sequences: 10 sentences, each padded to length 18.
# Every row begins with 2 and its content ends with 5, 3 — presumably
# BOS/EOS-style special tokens (TODO confirm); trailing 1s are padding
# (1 is the pad_idx passed to get_pad_mask below).
x = torch.tensor([[ 2, 4556, 4573, 4569, 4842, 9210, 4675, 4790, 4590, 4668, 4684, 4642,
                 4979, 5, 3, 1, 1, 1],
                 [ 2, 4726, 5209, 4600, 4561, 4572, 4663, 5, 3, 1, 1, 1,
                 1, 1, 1, 1, 1, 1],
                 [ 2, 4567, 4662, 15, 4556, 4857, 4559, 4556, 4615, 15, 4603, 4674,
                 4569, 4659, 5, 3, 1, 1],
                 [ 2, 4556, 6020, 4562, 6, 4563, 4694, 4633, 7903, 4581, 4683, 4560,
                 4563, 5314, 4833, 4664, 5, 3],
                 [ 2, 4591, 4578, 6490, 4609, 4570, 4557, 8561, 5, 3, 1, 1,
                 1, 1, 1, 1, 1, 1],
                 [ 2, 4556, 4562, 4870, 4575, 4563, 7318, 15, 5222, 4680, 5599, 6106,
                 4628, 5, 3, 1, 1, 1],
                 [ 2, 4558, 4565, 6, 4768, 4588, 4585, 4581, 394, 6, 4566, 5227,
                 5, 3, 1, 1, 1, 1],
                 [ 2, 4556, 5062, 4585, 4571, 2647, 941, 4592, 0, 4944, 5, 3,
                 1, 1, 1, 1, 1, 1],
                 [ 2, 4556, 4714, 4573, 4560, 4758, 4661, 4560, 7566, 4559, 0, 4606,
                 21, 4557, 7764, 4658, 5, 3],
                 [ 2, 4556, 4615, 4574, 6, 4557, 0, 4588, 4579, 21, 4563, 4582,
                 5, 3, 1, 1, 1, 1]])
x.shape
torch.Size([10, 18])
def get_pad_mask(seq, pad_idx):
    """Build a boolean padding mask for a batch of token-id sequences.

    Returns a (batch, 1, seq_len) tensor that is True at real tokens and
    False where the token equals ``pad_idx``; the singleton middle dim lets
    the mask broadcast across the rows of an attention matrix.
    """
    keep = seq.ne(pad_idx)       # True where the token is not padding
    return keep.unsqueeze(-2)    # (B, L) -> (B, 1, L) for broadcasting
# (10, 1, 18) boolean mask; pad token id is 1, so padded slots become False.
src_mask = get_pad_mask(x, 1)
src_mask
tensor([[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, False,
False, False, False, False, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True]],
[[ True, True, True, True, True, True, True, True, True, True,
False, False, False, False, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, False, False, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True]],
[[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, False, False, False, False]]])
src_mask.shape
torch.Size([10, 1, 18])
# Demo dimensions — must match x above: 10 sentences of length 18.
batch_size = 10
seq_length = 18
embedding_size = 16
# Stand-in for an embedding lookup: in a real model each token id would index
# a row of a randomly initialized (vocab_size, embedding_size) weight table;
# here random vectors are drawn directly for the demo.
x_word_embedding = torch.randn([batch_size,seq_length,embedding_size])
x_word_embedding.shape
torch.Size([10, 18, 16])
# Zero out the embedding vectors at padding positions (token id 1) so that
# padded slots carry no signal.  Boolean-mask assignment does in one
# vectorized step what the original element-by-element double Python loop
# did, with identical results.
x_word_embedding[x == 1] = 0
x_word_embedding[0]
tensor([[ 0.5953, -0.7687, -0.1396, 0.0551, -0.1316, 2.1182, 0.3460, -0.2474,
-0.7631, 0.7371, 0.2647, -0.9265, -0.0795, -0.9688, -1.9949, -1.0153],
[-1.8476, 0.1382, -1.7616, 0.7693, 0.2297, -0.4240, -0.2053, -1.6727,
-0.9341, -0.6124, -0.3068, -0.9722, -0.5115, -1.0359, -0.7867, 0.5991],
[-0.6086, 0.8565, -0.5559, 2.0990, -0.5317, -0.0794, 0.4806, 0.3160,
1.0327, -0.2135, -0.1543, 1.6888, 1.1976, -1.2593, -0.5984, -0.3219],
[-1.2501, 0.6738, 1.1699, 0.3973, 0.0653, -1.2225, -0.2438, -0.6155,
-0.6204, -1.8898, -0.5140, -2.0530, 0.7663, 0.4650, -1.0250, -1.1067],
[-0.5502, 0.9821, 1.2645, -1.5540, 0.8937, 0.3531, -1.0345, 1.5634,
-0.4247, 0.8320, -1.1106, 1.2879, 0.2906, 2.3099, -1.6235, -0.9464],
[-0.2324, -0.3651, 0.4145, 0.9106, 0.1187, -0.2084, 0.2424, 1.0556,
-2.1255, 1.6913, -1.0605, 1.6755, 1.0412, 0.2087, 1.6073, 0.1279],
[ 0.2160, 0.7791, -0.4280, -0.0124, -2.2770, -0.9033, -0.1782, -1.2013,
0.9789, 0.2439, 0.4394, -0.5341, -0.5609, 1.8852, -0.3370, -0.1769],
[ 2.1392, 1.3648, 0.1819, 0.8782, -1.5937, -0.3102, -0.0304, -0.2514,
-0.7271, 1.2352, -0.7463, 0.9655, 1.7330, -0.0797, 1.7071, 0.0998],
[-0.3247, 0.2992, -1.9922, 0.1976, -0.1732, 1.1075, -3.4594, -0.1730,
1.4339, 0.9275, 0.0803, 0.4825, -1.1535, -1.1750, 1.0148, 0.3328],
[ 0.5756, -0.5909, -1.0873, -0.5838, 0.1695, 0.7564, -0.1989, -1.0383,
-0.1502, 1.5897, -0.5680, -3.2284, -0.8693, -2.0997, -0.1233, 1.1241],
[ 0.5985, -0.6717, -0.4839, -1.4869, 0.2176, -0.4596, 2.2802, 0.3979,
-0.5376, -0.5777, 1.2366, 0.1043, -0.6600, 0.1091, -1.1323, 0.2859],
[ 1.3520, -0.4221, 0.8972, -1.2319, 0.2628, -0.3468, 0.4029, 0.3405,
-1.1161, 0.5008, 1.0744, -1.6060, 2.6297, 1.6144, 0.2159, 0.6276],
[ 0.3464, -1.5550, -1.5556, -0.3224, 0.6951, 1.0998, 1.8592, 0.7744,
0.9877, 0.7624, 2.5627, 0.5629, -1.3892, 0.4832, -0.1809, -0.3011],
[-0.6122, -0.1236, 0.1143, -0.4748, -1.1046, -1.6141, -1.2581, -0.6381,
-1.3857, -1.8965, 0.4942, 1.1739, 1.0684, -0.9261, -0.6500, -0.2302],
[ 0.3307, -0.2460, 1.9470, 0.1465, -0.9497, 0.3327, -0.8901, 1.1441,
-0.4879, 1.0638, -1.4439, -2.0664, -0.8945, 0.5613, 0.7411, -0.2546],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
# (1, 18, 16): one random vector per position, shared across the whole batch
# via broadcasting.  NOTE(review): a real Transformer would use sinusoidal or
# learned positional encodings; random vectors only illustrate the shapes.
x_postion_embedding = torch.randn([18,16]).unsqueeze(0)
x_postion_embedding
tensor([[[ 9.0368e-01, 5.8748e-01, -1.1009e+00, -1.0503e-01, -2.1601e+00,
1.5621e+00, 4.6033e-01, 1.2967e+00, -2.1629e+00, -7.8287e-02,
-9.9411e-01, 7.6217e-01, 4.9960e-02, 4.0106e-01, -2.0348e+00,
4.9369e-01],
[ 1.0890e+00, 2.2973e+00, 9.6957e-01, 1.3043e+00, 6.0618e-02,
1.9305e+00, 3.8956e-02, -5.6996e-01, -1.2326e+00, -9.9662e-01,
-6.2974e-01, -1.7874e+00, 2.6497e-01, -3.9136e-01, -1.5453e+00,
-5.2771e-01],
[-8.4272e-01, -1.0274e+00, 3.9564e-01, 2.3713e-02, -8.5774e-02,
-1.9809e+00, 6.5302e-01, -9.5564e-02, -1.3990e-01, -2.4967e-01,
-1.3083e-01, 1.2046e+00, 9.8682e-01, -8.6967e-01, 1.8375e-01,
4.2060e-02],
[ 1.3766e+00, -1.2898e+00, -5.6401e-01, 8.1396e-01, 8.3389e-01,
-1.6490e+00, -1.0349e+00, 3.0295e-01, 3.5253e-01, 1.4929e+00,
5.5201e-01, 1.2799e+00, -4.9722e-01, -1.5553e-01, 3.7034e-01,
1.8389e-01],
[ 5.6875e-01, -1.9672e+00, -1.7386e+00, 7.6477e-01, -1.7213e+00,
-9.3602e-01, 9.6785e-01, 4.8722e-01, 6.3193e-02, 3.8820e-01,
-5.6842e-02, 5.4492e-01, 8.9291e-02, -7.8848e-01, -7.5108e-01,
-2.1830e-01],
[-1.3706e+00, -3.7459e-01, -1.4454e+00, -6.9120e-02, -8.5179e-03,
1.0005e+00, -7.9239e-01, 3.5959e-01, -5.8618e-01, 1.3782e+00,
-1.5627e-01, -2.4377e-01, 5.0945e-01, -6.9910e-01, 3.4893e-01,
-8.7715e-01],
[-9.7867e-01, -9.5280e-01, 2.3949e-01, -5.0913e-01, 2.8120e-01,
6.5280e-01, 1.5024e+00, -8.8424e-01, -4.7504e-01, 4.8769e-01,
1.1893e+00, -6.1249e-01, -6.9277e-01, -2.6621e-01, -1.0576e+00,
5.5870e-01],
[ 1.5670e+00, 1.4179e+00, -4.3079e-01, 9.2860e-03, 9.1036e-02,
-8.8635e-01, -7.3382e-01, 9.9911e-02, 1.0422e+00, 3.3808e-02,
-4.3342e-01, -4.2554e-01, -1.2664e+00, 6.7073e-01, -4.3843e-01,
2.4629e-01],
[-4.0307e-01, -8.0981e-01, -7.3736e-01, 7.0317e-01, -6.7180e-01,
3.4013e-02, 1.8024e-01, 6.2083e-01, -3.4933e-01, -6.2524e-01,
-3.5893e-01, 1.1627e+00, -2.3080e-01, -9.5184e-01, 7.1967e-01,
-6.2196e-01],
[ 5.5927e-01, -8.6444e-01, 1.4131e-01, 1.0069e+00, -5.3246e-01,
-1.3804e+00, -1.2412e+00, -6.2083e-01, 1.2844e+00, -8.0976e-01,
-1.3851e+00, -2.3795e-01, 7.9758e-01, -1.4378e+00, -1.5737e+00,
-8.9958e-01],
[-1.7257e-01, -8.2184e-01, -1.2956e+00, 1.5484e+00, -5.5334e-02,
9.2041e-01, -8.8723e-01, 6.2453e-01, -1.0932e-02, 1.8479e+00,
1.9423e-01, 9.7021e-01, 2.9241e-01, 7.9268e-01, -5.6227e-01,
-1.6743e+00],
[ 1.2613e-01, 1.2770e-01, 7.3058e-01, -6.5678e-01, -9.4056e-01,
-1.2946e+00, -7.7965e-01, 1.2064e-01, -1.8496e+00, 1.9518e+00,
-1.6807e+00, 9.9174e-01, 5.3432e-01, -2.4413e-01, -9.0537e-01,
8.8707e-01],
[ 1.6266e+00, -1.2979e+00, 2.1201e-01, -1.2506e+00, -2.5751e-01,
-1.4188e+00, 6.8915e-01, 7.5310e-01, 1.0268e+00, -4.7283e-02,
6.0086e-01, -2.2722e+00, 1.4796e+00, 1.6694e-01, 9.7610e-01,
-3.5627e-01],
[ 6.9922e-01, -2.8308e-01, 1.8916e+00, 4.8927e-01, -1.2071e+00,
5.6262e-01, 7.0220e-01, -9.3077e-01, -4.1099e-01, -8.8629e-01,
-1.4075e+00, 1.4386e-01, -1.2417e+00, 3.6809e-02, 3.5345e-01,
-1.0575e+00],
[-6.3683e-01, 9.1762e-01, 4.9119e-01, -2.2506e+00, -8.0021e-01,
2.0614e-01, -1.2295e+00, -1.1974e+00, 2.1372e-01, -9.5220e-01,
-1.9652e+00, -4.0965e-02, 1.4382e+00, -1.1599e+00, 6.7839e-01,
9.6810e-01],
[ 3.6581e-01, -3.6475e-01, 7.9087e-01, -9.9523e-01, 1.0741e+00,
3.9176e-01, 3.8599e-01, 1.0822e+00, -2.7792e-01, -1.0478e-01,
1.7455e+00, -1.2391e+00, 4.7921e-01, 5.7139e-01, 7.3661e-01,
-3.4478e-01],
[ 1.3962e-01, 1.7028e-01, -3.8022e-01, 2.7366e+00, -4.0408e-01,
2.1577e+00, -1.1733e-01, 2.2820e+00, -1.0575e+00, -7.8049e-02,
-1.8555e+00, -1.5639e+00, -2.2194e-03, 9.1327e-01, -6.8152e-01,
4.3902e-01],
[ 1.3193e+00, 2.4797e-01, 7.9764e-01, 9.9411e-02, -5.5226e-01,
-6.8628e-02, -3.8057e-01, 1.5772e+00, 8.9885e-01, -6.5663e-01,
2.7259e+00, 5.9226e-01, -2.0406e+00, 5.5416e-01, -2.2941e-01,
-4.6528e-01]]])
x_postion_embedding.shape  # the same position in every sentence uses the same position vector.
torch.Size([1, 18, 16])
# (10,18,16) + (1,18,16): broadcasting adds the shared position vectors to
# every sentence's word embeddings.
x_embedding = x_word_embedding + x_postion_embedding
x_embedding
tensor([[[ 1.4990, -0.1813, -1.2405, ..., -0.5678, -4.0296, -0.5216],
[-0.7586, 2.4355, -0.7921, ..., -1.4273, -2.3320, 0.0714],
[-1.4513, -0.1709, -0.1603, ..., -2.1290, -0.4146, -0.2798],
...,
[ 0.3658, -0.3647, 0.7909, ..., 0.5714, 0.7366, -0.3448],
[ 0.1396, 0.1703, -0.3802, ..., 0.9133, -0.6815, 0.4390],
[ 1.3193, 0.2480, 0.7976, ..., 0.5542, -0.2294, -0.4653]],
[[ 1.2682, -0.8150, -0.8226, ..., 0.0490, -1.8821, 0.6543],
[ 1.2540, 2.3416, 0.3683, ..., 0.4880, 0.1284, -1.2846],
[-0.9080, 0.3243, -2.1975, ..., -2.0137, 0.2239, -0.0376],
...,
[ 0.3658, -0.3647, 0.7909, ..., 0.5714, 0.7366, -0.3448],
[ 0.1396, 0.1703, -0.3802, ..., 0.9133, -0.6815, 0.4390],
[ 1.3193, 0.2480, 0.7976, ..., 0.5542, -0.2294, -0.4653]],
[[ 1.8732, 0.5552, -1.7469, ..., 1.8612, -2.0750, 0.5315],
[ 0.3223, 2.0160, 2.4301, ..., -1.7063, -1.3378, 0.2593],
[-0.3029, -1.4872, -1.3034, ..., 0.5289, -0.0337, 0.4721],
...,
[-0.1261, -1.1610, 0.8987, ..., 1.1619, 2.2077, -0.1100],
[ 0.1396, 0.1703, -0.3802, ..., 0.9133, -0.6815, 0.4390],
[ 1.3193, 0.2480, 0.7976, ..., 0.5542, -0.2294, -0.4653]],
...,
[[ 0.6126, 0.9536, -1.1994, ..., 1.2692, -0.6369, 1.7213],
[ 2.0293, 2.3032, 0.5438, ..., -1.0498, -0.9796, -0.5360],
[ 0.2768, -1.1700, -3.0397, ..., -0.9378, 0.1812, -0.1183],
...,
[ 0.3658, -0.3647, 0.7909, ..., 0.5714, 0.7366, -0.3448],
[ 0.1396, 0.1703, -0.3802, ..., 0.9133, -0.6815, 0.4390],
[ 1.3193, 0.2480, 0.7976, ..., 0.5542, -0.2294, -0.4653]],
[[ 0.2253, -0.3884, -1.9119, ..., -1.1351, -1.7859, 0.1309],
[ 0.7697, 1.8575, 0.9080, ..., -0.9043, 0.5719, -2.3765],
[-0.3397, -1.2380, 0.0837, ..., -1.6820, 0.8802, -0.0421],
...,
[ 0.5816, -1.1825, 0.6010, ..., 0.6708, -0.8309, -0.3200],
[ 1.7076, -0.1556, -0.1022, ..., 1.1021, 0.6713, 0.8018],
[ 2.7562, -0.3169, 0.5713, ..., 0.3990, -0.1851, -0.3272]],
[[ 2.5735, 1.8996, 0.4187, ..., -0.0281, -1.4581, 0.0753],
[-0.0460, 1.7525, 1.3994, ..., -0.5160, -2.1990, -2.7955],
[ 0.3476, -0.7091, -1.4106, ..., -1.6300, -0.2675, 0.6924],
...,
[ 0.3658, -0.3647, 0.7909, ..., 0.5714, 0.7366, -0.3448],
[ 0.1396, 0.1703, -0.3802, ..., 0.9133, -0.6815, 0.4390],
[ 1.3193, 0.2480, 0.7976, ..., 0.5542, -0.2294, -0.4653]]])
x_embedding.shape
torch.Size([10, 18, 16])
# Re-declare the demo dimensions plus a common head size for Q/K/V.
batch_size = 10
seq_length = 18
embedding_size = 16
q_d=k_d = v_d= 20
# Project embeddings into query/key/value spaces: (10,18,16) -> (10,18,20).
# NOTE(review): constructing nn.Linear inline yields randomly initialized
# (untrained) projections — fine for demonstrating shapes, but a real model
# would reuse trained layers held on a module.
q = nn.Linear(embedding_size,q_d,bias=False)(x_embedding)
k = nn.Linear(embedding_size,k_d,bias=False)(x_embedding)
v = nn.Linear(embedding_size,v_d,bias=False)(x_embedding)
q.shape,k.shape,v.shape
(torch.Size([10, 18, 20]), torch.Size([10, 18, 20]), torch.Size([10, 18, 20]))
# Scaled dot-product attention logits: (10,18,20) x (10,20,18) -> (10,18,18).
# Scale by sqrt(d_k) as in "Attention Is All You Need"; the original
# hard-coded 4 (= sqrt(16)), but the key dimension here is k_d = 20,
# so the correct scale is sqrt(20) ≈ 4.47.
attn = torch.matmul(q / (k_d ** 0.5), k.transpose(1, 2))
attn.shape
torch.Size([10, 18, 18])
attn[0].shape
torch.Size([18, 18])
# Broadcast the (10,1,18) mask over every query row: logits at padded key
# positions are set to -1e9 so softmax drives their weights to ~0.
attn = attn.masked_fill(src_mask == 0, -1e9)
attn[0]
tensor([[-6.0690e-01, -3.7683e-01, 1.5225e+00, 9.6109e-01, 9.9332e-01,
1.8536e-01, 7.5841e-01, 8.1171e-01, -4.7227e-01, -1.6088e+00,
6.3255e-01, -2.4172e-01, -3.0882e-01, 1.3427e+00, -1.1389e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-3.0765e-01, -1.5715e-01, 4.4040e-01, 4.1949e-01, -1.4558e-01,
-2.8926e-01, 5.3010e-01, 5.7046e-02, -6.6750e-01, -1.2537e+00,
2.4509e-01, -4.4911e-01, 4.4740e-01, 1.5970e+00, -6.1908e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-1.1161e-01, 1.1511e+00, -1.7925e+00, 2.2255e-01, -1.0840e+00,
-7.0142e-02, -5.0922e-01, 1.7571e-01, -5.4323e-01, -4.8321e-01,
-1.2211e+00, 3.6905e-01, -1.1800e+00, 8.3300e-01, 1.5945e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 5.4233e-02, -5.0000e-02, -6.8838e-01, -1.3039e-01, -5.2410e-01,
-5.5713e-01, 1.9288e-02, 5.0029e-02, -7.9929e-02, 4.1053e-01,
-3.4310e-01, -1.2978e-01, 3.6358e-01, 3.4465e-01, 5.2046e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-8.8978e-02, 8.0073e-01, 7.4081e-02, 4.5606e-01, 1.7942e-01,
-5.0651e-02, 2.7160e-01, 6.4326e-01, -3.5353e-01, -1.7773e-01,
-2.0484e-01, 6.1417e-03, -1.0183e+00, 5.1489e-01, 1.6607e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-3.6375e-01, 9.5096e-01, -1.0274e+00, -1.5157e-01, -1.2807e+00,
-2.5263e-01, 3.9632e-02, 6.1543e-01, 4.6237e-01, 3.3684e-01,
-6.3857e-01, -4.8063e-01, -2.0782e-01, -5.5224e-01, 1.7934e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 2.8373e-01, -1.0136e-01, 1.1654e+00, 9.3107e-01, 1.0568e+00,
-7.8805e-02, 5.1141e-01, -9.4328e-01, -6.2343e-01, 4.3079e-01,
1.0431e+00, -1.0827e-01, 7.1189e-01, 3.3160e-01, -1.1183e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 4.8535e-01, -1.1663e+00, 6.6636e-01, -9.6097e-01, 8.5208e-01,
9.9473e-01, -5.2427e-02, 4.5534e-01, 4.1307e-01, -8.3508e-01,
1.5996e+00, 8.0137e-01, 1.3822e+00, -2.4668e+00, -2.2721e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-4.1034e-02, -6.5552e-01, 2.6044e-01, 2.9065e-01, -3.2397e-01,
2.3684e-01, 8.1972e-03, 1.0512e+00, -7.9683e-01, 4.2159e-01,
-1.9782e-01, 1.7893e+00, 7.8592e-01, 4.2542e-02, 2.4891e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 6.4500e-02, -1.8950e+00, -2.1693e-01, 1.0840e-01, 2.2477e-01,
6.7428e-01, -3.7683e-01, 3.2650e-01, 5.2687e-01, 5.6838e-01,
1.4969e-01, 1.8243e+00, 1.1127e+00, -2.8123e-01, 4.6841e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 8.2095e-02, 4.1222e-01, 3.1565e-01, 8.2741e-01, 4.0937e-01,
-6.6774e-01, 9.6502e-01, 4.8628e-02, -1.1333e+00, -2.4812e-01,
2.4803e-01, -2.0652e-01, 9.9369e-03, 1.1184e+00, -2.5702e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-8.6849e-01, -1.8334e-01, -4.9383e-01, -2.5585e-01, -3.5734e-01,
6.4596e-01, -6.3123e-01, -1.7421e-01, 2.5318e+00, -2.5341e-01,
5.8007e-01, -1.5293e+00, 2.2613e-01, -2.1622e+00, -1.7305e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 4.9297e-01, -8.2548e-01, 3.7337e-01, 3.2121e-01, 1.0958e+00,
-3.7185e-01, 4.2499e-01, -9.7494e-01, 1.0195e-01, 8.9203e-01,
6.9874e-01, -1.8231e-01, 7.8013e-01, -3.2680e-01, -2.0724e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-1.1754e-02, -2.1956e-01, -1.8653e-01, 5.9911e-02, 4.4472e-01,
1.3811e-01, -6.3979e-01, -5.6150e-02, 5.5480e-01, 2.0387e-01,
-7.0942e-01, -1.7152e-01, -1.0433e+00, 1.2044e+00, 1.6884e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-1.0712e-01, -6.6093e-01, 1.4028e-02, -9.5958e-01, 7.4170e-02,
1.2432e+00, -1.2309e+00, 3.3976e-01, 2.2499e+00, 9.1287e-01,
-3.0829e-01, 3.4668e-01, -6.4126e-01, -1.6625e+00, 6.1514e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-1.0567e-01, 6.0719e-02, -2.0683e-01, -2.4056e-01, -1.4933e-01,
-3.5809e-01, 1.3756e-01, -2.6875e-01, 4.9080e-01, 1.6693e-01,
-1.7357e-01, -8.6498e-01, -1.1757e-01, -3.3879e-02, 1.9470e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[-2.3758e-01, 4.7806e-02, 1.3244e+00, -1.9672e-01, -5.8326e-02,
-4.1213e-01, 3.4664e-01, 6.9897e-01, -6.6317e-01, -6.6268e-01,
3.4629e-01, -4.4367e-01, 4.2473e-01, 5.9423e-02, -1.2536e+00,
-1.0000e+09, -1.0000e+09, -1.0000e+09],
[ 1.0423e-01, -4.4003e-01, 9.3847e-01, -1.3148e-01, 4.3072e-01,
-2.8801e-01, 4.1276e-01, 3.0477e-01, -9.8558e-01, -1.9937e-01,
1.7873e-01, 3.9120e-01, 3.3835e-01, 3.1912e-01, -1.6030e-01,
-1.0000e+09, -1.0000e+09, -1.0000e+09]], grad_fn=<SelectBackward0>)
# Normalize each query row into attention weights; the -1e9 (padding)
# columns come out as 0 and each row sums to 1.
scores = F.softmax(attn,dim=-1)
scores[0]
tensor([[0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
[0.0391, 0.0455, 0.0827, 0.0810, 0.0460, 0.0399, 0.0904, 0.0563, 0.0273,
0.0152, 0.0680, 0.0340, 0.0832, 0.2628, 0.0287, 0.0000, 0.0000, 0.0000],
[0.0471, 0.1663, 0.0088, 0.0657, 0.0178, 0.0490, 0.0316, 0.0627, 0.0306,
0.0325, 0.0155, 0.0761, 0.0162, 0.1210, 0.2592, 0.0000, 0.0000, 0.0000],
[0.0696, 0.0627, 0.0331, 0.0578, 0.0390, 0.0378, 0.0672, 0.0693, 0.0608,
0.0994, 0.0468, 0.0579, 0.0948, 0.0930, 0.1109, 0.0000, 0.0000, 0.0000],
[0.0516, 0.1256, 0.0607, 0.0890, 0.0675, 0.0536, 0.0740, 0.1073, 0.0396,
0.0472, 0.0459, 0.0567, 0.0204, 0.0944, 0.0666, 0.0000, 0.0000, 0.0000],
[0.0458, 0.1707, 0.0236, 0.0567, 0.0183, 0.0512, 0.0686, 0.1220, 0.1047,
0.0923, 0.0348, 0.0408, 0.0536, 0.0380, 0.0789, 0.0000, 0.0000, 0.0000],
[0.0567, 0.0386, 0.1370, 0.1084, 0.1229, 0.0395, 0.0713, 0.0166, 0.0229,
0.0657, 0.1213, 0.0383, 0.0871, 0.0595, 0.0140, 0.0000, 0.0000, 0.0000],
[0.0646, 0.0124, 0.0775, 0.0152, 0.0933, 0.1076, 0.0378, 0.0627, 0.0601,
0.0173, 0.1970, 0.0887, 0.1585, 0.0034, 0.0041, 0.0000, 0.0000, 0.0000],
[0.0412, 0.0223, 0.0557, 0.0575, 0.0311, 0.0544, 0.0433, 0.1229, 0.0194,
0.0655, 0.0353, 0.2572, 0.0943, 0.0448, 0.0551, 0.0000, 0.0000, 0.0000],
[0.0433, 0.0061, 0.0327, 0.0452, 0.0508, 0.0797, 0.0278, 0.0563, 0.0687,
0.0717, 0.0471, 0.2516, 0.1235, 0.0306, 0.0648, 0.0000, 0.0000, 0.0000],
[0.0543, 0.0756, 0.0686, 0.1144, 0.0753, 0.0257, 0.1313, 0.0525, 0.0161,
0.0390, 0.0641, 0.0407, 0.0505, 0.1531, 0.0387, 0.0000, 0.0000, 0.0000],
[0.0178, 0.0354, 0.0260, 0.0329, 0.0297, 0.0811, 0.0226, 0.0357, 0.5348,
0.0330, 0.0759, 0.0092, 0.0533, 0.0049, 0.0075, 0.0000, 0.0000, 0.0000],
[0.0795, 0.0213, 0.0705, 0.0669, 0.1452, 0.0335, 0.0742, 0.0183, 0.0537,
0.1184, 0.0976, 0.0404, 0.1059, 0.0350, 0.0395, 0.0000, 0.0000, 0.0000],
[0.0465, 0.0378, 0.0390, 0.0499, 0.0734, 0.0540, 0.0248, 0.0445, 0.0819,
0.0577, 0.0231, 0.0396, 0.0166, 0.1568, 0.2545, 0.0000, 0.0000, 0.0000],
[0.0349, 0.0201, 0.0394, 0.0149, 0.0418, 0.1347, 0.0113, 0.0546, 0.3685,
0.0968, 0.0285, 0.0549, 0.0205, 0.0074, 0.0719, 0.0000, 0.0000, 0.0000],
[0.0634, 0.0749, 0.0573, 0.0554, 0.0607, 0.0493, 0.0809, 0.0539, 0.1152,
0.0833, 0.0593, 0.0297, 0.0627, 0.0682, 0.0857, 0.0000, 0.0000, 0.0000],
[0.0453, 0.0602, 0.2159, 0.0472, 0.0542, 0.0380, 0.0812, 0.1155, 0.0296,
0.0296, 0.0812, 0.0369, 0.0878, 0.0609, 0.0164, 0.0000, 0.0000, 0.0000],
[0.0622, 0.0361, 0.1433, 0.0491, 0.0862, 0.0420, 0.0847, 0.0760, 0.0209,
0.0459, 0.0670, 0.0829, 0.0786, 0.0771, 0.0478, 0.0000, 0.0000, 0.0000]],
grad_fn=<SelectBackward0>)
sum(scores[0][0])
tensor(1.0000, grad_fn=<AddBackward0>)
scores.shape
torch.Size([10, 18, 18])
v.shape
torch.Size([10, 18, 20])
v[0].shape
torch.Size([18, 20])
scores[0]
tensor([[0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
[0.0391, 0.0455, 0.0827, 0.0810, 0.0460, 0.0399, 0.0904, 0.0563, 0.0273,
0.0152, 0.0680, 0.0340, 0.0832, 0.2628, 0.0287, 0.0000, 0.0000, 0.0000],
[0.0471, 0.1663, 0.0088, 0.0657, 0.0178, 0.0490, 0.0316, 0.0627, 0.0306,
0.0325, 0.0155, 0.0761, 0.0162, 0.1210, 0.2592, 0.0000, 0.0000, 0.0000],
[0.0696, 0.0627, 0.0331, 0.0578, 0.0390, 0.0378, 0.0672, 0.0693, 0.0608,
0.0994, 0.0468, 0.0579, 0.0948, 0.0930, 0.1109, 0.0000, 0.0000, 0.0000],
[0.0516, 0.1256, 0.0607, 0.0890, 0.0675, 0.0536, 0.0740, 0.1073, 0.0396,
0.0472, 0.0459, 0.0567, 0.0204, 0.0944, 0.0666, 0.0000, 0.0000, 0.0000],
[0.0458, 0.1707, 0.0236, 0.0567, 0.0183, 0.0512, 0.0686, 0.1220, 0.1047,
0.0923, 0.0348, 0.0408, 0.0536, 0.0380, 0.0789, 0.0000, 0.0000, 0.0000],
[0.0567, 0.0386, 0.1370, 0.1084, 0.1229, 0.0395, 0.0713, 0.0166, 0.0229,
0.0657, 0.1213, 0.0383, 0.0871, 0.0595, 0.0140, 0.0000, 0.0000, 0.0000],
[0.0646, 0.0124, 0.0775, 0.0152, 0.0933, 0.1076, 0.0378, 0.0627, 0.0601,
0.0173, 0.1970, 0.0887, 0.1585, 0.0034, 0.0041, 0.0000, 0.0000, 0.0000],
[0.0412, 0.0223, 0.0557, 0.0575, 0.0311, 0.0544, 0.0433, 0.1229, 0.0194,
0.0655, 0.0353, 0.2572, 0.0943, 0.0448, 0.0551, 0.0000, 0.0000, 0.0000],
[0.0433, 0.0061, 0.0327, 0.0452, 0.0508, 0.0797, 0.0278, 0.0563, 0.0687,
0.0717, 0.0471, 0.2516, 0.1235, 0.0306, 0.0648, 0.0000, 0.0000, 0.0000],
[0.0543, 0.0756, 0.0686, 0.1144, 0.0753, 0.0257, 0.1313, 0.0525, 0.0161,
0.0390, 0.0641, 0.0407, 0.0505, 0.1531, 0.0387, 0.0000, 0.0000, 0.0000],
[0.0178, 0.0354, 0.0260, 0.0329, 0.0297, 0.0811, 0.0226, 0.0357, 0.5348,
0.0330, 0.0759, 0.0092, 0.0533, 0.0049, 0.0075, 0.0000, 0.0000, 0.0000],
[0.0795, 0.0213, 0.0705, 0.0669, 0.1452, 0.0335, 0.0742, 0.0183, 0.0537,
0.1184, 0.0976, 0.0404, 0.1059, 0.0350, 0.0395, 0.0000, 0.0000, 0.0000],
[0.0465, 0.0378, 0.0390, 0.0499, 0.0734, 0.0540, 0.0248, 0.0445, 0.0819,
0.0577, 0.0231, 0.0396, 0.0166, 0.1568, 0.2545, 0.0000, 0.0000, 0.0000],
[0.0349, 0.0201, 0.0394, 0.0149, 0.0418, 0.1347, 0.0113, 0.0546, 0.3685,
0.0968, 0.0285, 0.0549, 0.0205, 0.0074, 0.0719, 0.0000, 0.0000, 0.0000],
[0.0634, 0.0749, 0.0573, 0.0554, 0.0607, 0.0493, 0.0809, 0.0539, 0.1152,
0.0833, 0.0593, 0.0297, 0.0627, 0.0682, 0.0857, 0.0000, 0.0000, 0.0000],
[0.0453, 0.0602, 0.2159, 0.0472, 0.0542, 0.0380, 0.0812, 0.1155, 0.0296,
0.0296, 0.0812, 0.0369, 0.0878, 0.0609, 0.0164, 0.0000, 0.0000, 0.0000],
[0.0622, 0.0361, 0.1433, 0.0491, 0.0862, 0.0420, 0.0847, 0.0760, 0.0209,
0.0459, 0.0670, 0.0829, 0.0786, 0.0771, 0.0478, 0.0000, 0.0000, 0.0000]],
grad_fn=<SelectBackward0>)
# Attention output = weighted sum of value vectors:
# (10,18,18) x (10,18,20) -> (10,18,20).
output = torch.matmul(scores, v)
output.shape
torch.Size([10, 18, 20])
output[0][0]
tensor([ 0.1268, 0.2170, -0.2703, 0.2993, -0.2006, -0.3046, 0.2196, -0.0223,
0.7093, -0.1499, -0.0298, -0.1743, 0.1631, -0.1972, -0.3261, 0.1589,
-0.1068, 0.0306, -0.2259, 0.0282], grad_fn=<SelectBackward0>)
scores.shape
torch.Size([10, 18, 18])
scores[0][0]
tensor([0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
grad_fn=<SelectBackward0>)
v.shape
torch.Size([10, 18, 20])
# We avoid matrix ops here to make the computation process clearer to
# everyone; in practice it is done with matrix operations (the matmul above).
# Recompute the first output row by hand: weight of query 0 on key i,
# times value vector i, for each position i.
output1 = []
for i in range(len(scores[0][0])):
    output1.append(scores[0][0][i].detach().numpy()*v[0][i].detach().numpy())
output1
[array([-0.00356334, 0.01569117, -0.02043939, -0.02862189, 0.00840736,
-0.01401182, 0.00872742, 0.03408484, 0.00019593, -0.00853396,
0.00951094, -0.01021247, 0.01006769, 0.04058198, 0.02450397,
-0.02066059, 0.02544492, 0.01739634, -0.00912306, -0.00440777],
dtype=float32),
array([ 0.02174385, 0.02181036, -0.00668943, -0.00284444, -0.01792286,
-0.05496887, 0.00642328, 0.06035611, 0.00246261, 0.03883798,
0.00429785, -0.00607146, 0.05036234, -0.01484454, 0.06557821,
-0.01962684, 0.02545645, 0.03820134, 0.00964658, -0.0618248 ],
dtype=float32),
array([ 0.00981068, 0.08241012, 0.0435107 , 0.26321438, -0.21280633,
-0.065763 , 0.22667795, -0.11789013, 0.1957324 , -0.09090293,
0.13033353, -0.17513736, -0.02819123, -0.27770212, -0.16380529,
-0.16696015, -0.22296761, -0.15442224, -0.14996761, -0.04623051],
dtype=float32),
array([ 0.05863579, -0.01563031, 0.02628604, 0.04352596, 0.00123114,
-0.12593815, -0.05850247, -0.04619041, 0.14190896, -0.00667505,
-0.06504327, -0.02457336, 0.01267471, -0.08988734, 0.02068387,
0.1141569 , -0.05952757, 0.09277385, 0.04706449, -0.00445588],
dtype=float32),
array([-0.01174144, 0.02150498, -0.06701917, -0.08503368, 0.05980498,
0.02390346, -0.06823318, -0.02788872, 0.00625273, -0.08193176,
-0.08811162, -0.04064612, -0.01717038, 0.01876009, -0.05648129,
-0.03940401, 0.04552145, 0.05847974, 0.02593314, 0.08681556],
dtype=float32),
array([ 0.0046064 , 0.0753553 , -0.05663994, -0.07728799, -0.05931624,
0.05093202, 0.03476338, 0.00069166, 0.00120347, 0.01253372,
-0.00336849, 0.00602924, 0.03808581, 0.06855204, 0.00205304,
-0.04280849, 0.02934068, 0.04572994, -0.10011108, 0.05379374],
dtype=float32),
array([ 0.00855533, 0.01972056, -0.02584991, 0.012734 , 0.04426347,
-0.01907244, -0.06026317, 0.08408913, -0.08700912, 0.02952116,
0.09744457, -0.05496382, 0.05024561, -0.04637939, -0.04548782,
0.02788595, 0.01097868, -0.02408975, 0.06984393, -0.06373438],
dtype=float32),
array([ 5.0336070e-02, -9.1202877e-02, 1.1917996e-02, 1.2563699e-01,
-4.5944070e-03, 5.7880916e-03, 1.5337752e-01, 4.5809049e-02,
8.7222785e-02, -7.0060484e-02, -1.0104631e-01, 6.6017680e-02,
8.4173165e-02, 2.7579835e-02, -6.1708100e-02, 5.5133272e-02,
3.2097664e-02, -2.4366794e-02, 1.0262818e-04, 8.8561904e-03],
dtype=float32),
array([-0.0292644 , 0.01630488, -0.01253164, -0.0070375 , -0.03961598,
0.05226943, 0.02208326, -0.03990135, 0.01585516, -0.00127062,
-0.0113592 , 0.02213981, 0.02005158, 0.01969342, -0.0070779 ,
-0.00585135, -0.00277652, -0.00163333, -0.03733382, 0.01124388],
dtype=float32),
array([ 6.2298100e-03, 1.8686900e-03, 2.7403934e-03, -3.5018041e-03,
-4.4463761e-03, -4.6206997e-03, 1.9569243e-03, 2.6030848e-03,
1.3119811e-02, 8.0437390e-03, 2.0034057e-03, -2.2230356e-03,
7.5591272e-03, -2.4101253e-04, 8.6891865e-03, -6.3293730e-05,
-9.3702022e-03, 2.1057955e-03, 1.9265222e-03, -1.9525980e-03],
dtype=float32),
array([-0.00207176, 0.03381819, -0.00062723, -0.05788399, 0.03948465,
0.03536145, 0.00404582, -0.01121849, -0.05850476, -0.0454405 ,
0.00821594, -0.05255793, -0.06612497, 0.04488177, -0.0387372 ,
-0.02739964, -0.03441686, 0.03766139, 0.0335264 , 0.05324583],
dtype=float32),
array([ 0.03819073, -0.02459533, -0.03861014, -0.01953255, 0.0140466 ,
-0.03961703, -0.01613682, 0.0444642 , 0.01884853, -0.0135273 ,
-0.01088831, 0.00310303, -0.00310043, 0.03728347, -0.01751549,
0.00054721, 0.04011041, 0.01901039, -0.03186378, 0.05084344],
dtype=float32),
array([-0.01383471, -0.03603299, 0.0306312 , -0.02414859, 0.02562755,
0.02530815, 0.00139578, -0.02436204, -0.02043813, -0.01283113,
0.0208585 , -0.01336401, -0.07284833, 0.03085884, -0.04238587,
0.02426553, -0.05287703, -0.0112904 , 0.03723892, 0.02734692],
dtype=float32),
array([-0.00945639, 0.09796111, -0.14330448, 0.1691841 , -0.0494328 ,
-0.17400241, -0.02735225, -0.03271891, 0.3773934 , 0.0788412 ,
-0.01655269, 0.09439775, 0.05920107, -0.06631016, -0.02382696,
0.2521283 , 0.04862823, -0.06399021, -0.11209722, -0.07742304],
dtype=float32),
array([-0.00139221, -0.00199078, -0.01365093, -0.009141 , -0.00532917,
-0.00019725, -0.00937385, 0.00572658, 0.01502065, 0.01348137,
-0.00606315, 0.0137309 , 0.01815296, 0.00999111, 0.0094426 ,
0.00756134, 0.01759486, -0.00099205, -0.01064409, -0.00396364],
dtype=float32),
array([-0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., -0., -0.,
0., -0., 0., -0., 0., 0., 0.], dtype=float32),
array([-0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., -0.], dtype=float32),
array([-0., -0., 0., 0., 0., 0., -0., -0., -0., -0., 0., -0., -0.,
0., -0., 0., -0., -0., 0., 0.], dtype=float32)]
import numpy as np
np.array(output1).reshape([18,-1])
array([[-3.56333586e-03, 1.56911686e-02, -2.04393901e-02,
-2.86218859e-02, 8.40735529e-03, -1.40118189e-02,
8.72741546e-03, 3.40848416e-02, 1.95933360e-04,
-8.53395835e-03, 9.51094087e-03, -1.02124671e-02,
1.00676939e-02, 4.05819751e-02, 2.45039724e-02,
-2.06605904e-02, 2.54449248e-02, 1.73963401e-02,
-9.12306458e-03, -4.40777233e-03],
[ 2.17438508e-02, 2.18103621e-02, -6.68943441e-03,
-2.84443772e-03, -1.79228578e-02, -5.49688675e-02,
6.42327731e-03, 6.03561103e-02, 2.46260897e-03,
3.88379842e-02, 4.29785158e-03, -6.07146276e-03,
5.03623448e-02, -1.48445424e-02, 6.55782074e-02,
-1.96268391e-02, 2.54564472e-02, 3.82013433e-02,
9.64657590e-03, -6.18248023e-02],
[ 9.81067866e-03, 8.24101195e-02, 4.35107015e-02,
2.63214380e-01, -2.12806329e-01, -6.57629967e-02,
2.26677954e-01, -1.17890127e-01, 1.95732400e-01,
-9.09029320e-02, 1.30333528e-01, -1.75137356e-01,
-2.81912331e-02, -2.77702123e-01, -1.63805291e-01,
-1.66960150e-01, -2.22967610e-01, -1.54422238e-01,
-1.49967611e-01, -4.62305136e-02],
[ 5.86357936e-02, -1.56303123e-02, 2.62860432e-02,
4.35259603e-02, 1.23114337e-03, -1.25938147e-01,
-5.85024655e-02, -4.61904071e-02, 1.41908959e-01,
-6.67505478e-03, -6.50432706e-02, -2.45733596e-02,
1.26747135e-02, -8.98873433e-02, 2.06838697e-02,
1.14156902e-01, -5.95275685e-02, 9.27738547e-02,
4.70644869e-02, -4.45587980e-03],
[-1.17414435e-02, 2.15049833e-02, -6.70191720e-02,
-8.50336775e-02, 5.98049797e-02, 2.39034556e-02,
-6.82331845e-02, -2.78887209e-02, 6.25273399e-03,
-8.19317624e-02, -8.81116167e-02, -4.06461246e-02,
-1.71703752e-02, 1.87600926e-02, -5.64812906e-02,
-3.94040123e-02, 4.55214530e-02, 5.84797412e-02,
2.59331409e-02, 8.68155584e-02],
[ 4.60639922e-03, 7.53552988e-02, -5.66399395e-02,
-7.72879943e-02, -5.93162440e-02, 5.09320162e-02,
3.47633846e-02, 6.91655499e-04, 1.20347342e-03,
1.25337243e-02, -3.36848991e-03, 6.02923892e-03,
3.80858146e-02, 6.85520396e-02, 2.05303659e-03,
-4.28084917e-02, 2.93406751e-02, 4.57299352e-02,
-1.00111082e-01, 5.37937433e-02],
[ 8.55532568e-03, 1.97205599e-02, -2.58499123e-02,
1.27340043e-02, 4.42634709e-02, -1.90724358e-02,
-6.02631718e-02, 8.40891302e-02, -8.70091245e-02,
2.95211598e-02, 9.74445716e-02, -5.49638234e-02,
5.02456091e-02, -4.63793948e-02, -4.54878211e-02,
2.78859455e-02, 1.09786754e-02, -2.40897462e-02,
6.98439255e-02, -6.37343824e-02],
[ 5.03360704e-02, -9.12028775e-02, 1.19179962e-02,
1.25636995e-01, -4.59440704e-03, 5.78809157e-03,
1.53377518e-01, 4.58090492e-02, 8.72227848e-02,
-7.00604841e-02, -1.01046309e-01, 6.60176799e-02,
8.41731653e-02, 2.75798347e-02, -6.17081001e-02,
5.51332720e-02, 3.20976637e-02, -2.43667942e-02,
1.02628183e-04, 8.85619037e-03],
[-2.92643961e-02, 1.63048804e-02, -1.25316372e-02,
-7.03750225e-03, -3.96159813e-02, 5.22694290e-02,
2.20832583e-02, -3.99013534e-02, 1.58551577e-02,
-1.27062469e-03, -1.13591971e-02, 2.21398100e-02,
2.00515836e-02, 1.96934212e-02, -7.07789976e-03,
-5.85135119e-03, -2.77651520e-03, -1.63333223e-03,
-3.73338237e-02, 1.12438807e-02],
[ 6.22980995e-03, 1.86869001e-03, 2.74039339e-03,
-3.50180408e-03, -4.44637612e-03, -4.62069968e-03,
1.95692433e-03, 2.60308478e-03, 1.31198112e-02,
8.04373901e-03, 2.00340571e-03, -2.22303555e-03,
7.55912717e-03, -2.41012531e-04, 8.68918654e-03,
-6.32937299e-05, -9.37020220e-03, 2.10579555e-03,
1.92652224e-03, -1.95259799e-03],
[-2.07175617e-03, 3.38181928e-02, -6.27233065e-04,
-5.78839891e-02, 3.94846499e-02, 3.53614539e-02,
4.04581567e-03, -1.12184929e-02, -5.85047640e-02,
-4.54404950e-02, 8.21593590e-03, -5.25579341e-02,
-6.61249682e-02, 4.48817722e-02, -3.87372039e-02,
-2.73996368e-02, -3.44168618e-02, 3.76613885e-02,
3.35264020e-02, 5.32458313e-02],
[ 3.81907299e-02, -2.45953333e-02, -3.86101417e-02,
-1.95325539e-02, 1.40466047e-02, -3.96170318e-02,
-1.61368232e-02, 4.44641970e-02, 1.88485272e-02,
-1.35272965e-02, -1.08883120e-02, 3.10303201e-03,
-3.10042920e-03, 3.72834690e-02, -1.75154917e-02,
5.47210744e-04, 4.01104130e-02, 1.90103948e-02,
-3.18637751e-02, 5.08434363e-02],
[-1.38347102e-02, -3.60329933e-02, 3.06312013e-02,
-2.41485890e-02, 2.56275497e-02, 2.53081508e-02,
1.39578118e-03, -2.43620351e-02, -2.04381309e-02,
-1.28311319e-02, 2.08585002e-02, -1.33640114e-02,
-7.28483349e-02, 3.08588445e-02, -4.23858687e-02,
2.42655315e-02, -5.28770313e-02, -1.12903966e-02,
3.72389220e-02, 2.73469202e-02],
[-9.45638865e-03, 9.79611129e-02, -1.43304482e-01,
1.69184104e-01, -4.94327955e-02, -1.74002409e-01,
-2.73522511e-02, -3.27189118e-02, 3.77393395e-01,
7.88412020e-02, -1.65526886e-02, 9.43977535e-02,
5.92010729e-02, -6.63101599e-02, -2.38269642e-02,
2.52128303e-01, 4.86282334e-02, -6.39902055e-02,
-1.12097219e-01, -7.74230361e-02],
[-1.39220979e-03, -1.99078326e-03, -1.36509314e-02,
-9.14100278e-03, -5.32916794e-03, -1.97253496e-04,
-9.37385298e-03, 5.72658470e-03, 1.50206508e-02,
1.34813692e-02, -6.06315304e-03, 1.37309004e-02,
1.81529596e-02, 9.99111217e-03, 9.44260042e-03,
7.56133953e-03, 1.75948571e-02, -9.92051209e-04,
-1.06440932e-02, -3.96363717e-03],
[-0.00000000e+00, -0.00000000e+00, 0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
-0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
0.00000000e+00, -0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00],
[-0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, -0.00000000e+00],
[-0.00000000e+00, -0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
-0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
-0.00000000e+00, 0.00000000e+00, -0.00000000e+00,
0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
0.00000000e+00, 0.00000000e+00]], dtype=float32)
# Single-head attention; multi-head would additionally Concat several such
# outputs.  Summing the 18 weighted value vectors reproduces output[0][0]
# (padding rows contribute exact zeros).
np.sum(np.array(output1).reshape([18,-1]),axis=-2)
array([ 0.12678438, 0.21699306, -0.27027592, 0.29926202, -0.20059839,
-0.3046291 , 0.21958958, -0.0223454 , 0.70926446, -0.14991458,
-0.02976831, -0.17433114, 0.16313875, -0.197182 , -0.32607505,
0.15890414, -0.10676249, 0.03057403, -0.22585806, 0.02815294],
dtype=float32)
output[0][0]
tensor([ 0.1268, 0.2170, -0.2703, 0.2993, -0.2006, -0.3046, 0.2196, -0.0223,
0.7093, -0.1499, -0.0298, -0.1743, 0.1631, -0.1972, -0.3261, 0.1589,
-0.1068, 0.0306, -0.2259, 0.0282], grad_fn=<SelectBackward0>)