padding_mask in Practice

  • When working with Transformer models, the padding_mask is an essential component. When processing variable-length sequences, it marks the padded positions so that the model does not take the meaningless padding values into account.

  • Below is some background on the padding_mask and a worked example of how to use it in code:

What Is a Padding Mask?

A padding mask is a Boolean mask matrix, typically used inside the attention mechanism to indicate which positions are padding. When computing attention, the model ignores these positions, so no attention weight is wasted on meaningless padding values and the model stays effective.
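As a toy illustration of the idea (assuming pad id 0 here; the walkthrough below uses pad id 1), a mask is True at real tokens and False at padding:

import torch
seq = torch.tensor([[5, 7, 0, 0],
                    [3, 9, 2, 0]])  # two sequences padded to length 4
mask = seq != 0
# tensor([[ True,  True, False, False],
#         [ True,  True,  True, False]])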

How Is a Padding Mask Used?

In a Transformer, the padding_mask is applied in the multi-head self-attention layers. Concretely, the attention logits at padded positions are set to negative infinity (in practice, a very large negative number), so that after the softmax these positions receive a weight of zero.
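In formula form (the standard scaled dot-product attention with an additive mask; a restatement, not code from the walkthrough below):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,\qquad M_{ij} = \begin{cases}0 & \text{if key } j \text{ is a real token}\\ -\infty & \text{if key } j \text{ is padding}\end{cases}$$

The walkthrough below builds every piece of this by hand.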

import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.tensor([[   2, 4556, 4573, 4569, 4842, 9210, 4675, 4790, 4590, 4668, 4684, 4642,
         4979,    5,    3,    1,    1,    1],
        [   2, 4726, 5209, 4600, 4561, 4572, 4663,    5,    3,    1,    1,    1,
            1,    1,    1,    1,    1,    1],
        [   2, 4567, 4662,   15, 4556, 4857, 4559, 4556, 4615,   15, 4603, 4674,
         4569, 4659,    5,    3,    1,    1],
        [   2, 4556, 6020, 4562,    6, 4563, 4694, 4633, 7903, 4581, 4683, 4560,
         4563, 5314, 4833, 4664,    5,    3],
        [   2, 4591, 4578, 6490, 4609, 4570, 4557, 8561,    5,    3,    1,    1,
            1,    1,    1,    1,    1,    1],
        [   2, 4556, 4562, 4870, 4575, 4563, 7318,   15, 5222, 4680, 5599, 6106,
         4628,    5,    3,    1,    1,    1],
        [   2, 4558, 4565,    6, 4768, 4588, 4585, 4581,  394,    6, 4566, 5227,
            5,    3,    1,    1,    1,    1],
        [   2, 4556, 5062, 4585, 4571, 2647,  941, 4592,    0, 4944,    5,    3,
            1,    1,    1,    1,    1,    1],
        [   2, 4556, 4714, 4573, 4560, 4758, 4661, 4560, 7566, 4559,    0, 4606,
           21, 4557, 7764, 4658,    5,    3],
        [   2, 4556, 4615, 4574,    6, 4557,    0, 4588, 4579,   21, 4563, 4582,
            5,    3,    1,    1,    1,    1]])
x.shape
torch.Size([10, 18])
def get_pad_mask(seq, pad_idx):
    # True at real tokens, False at padding; unsqueeze(-2) adds a query axis:
    # [batch, seq_len] -> [batch, 1, seq_len], which will broadcast over queries
    return (seq != pad_idx).unsqueeze(-2)
src_mask = get_pad_mask(x, 1)  # the pad token id in x is 1
src_mask
tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
          False, False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          False, False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True, False, False, False, False]]])
src_mask.shape
torch.Size([10, 1, 18])
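The shape matters: unsqueeze(-2) turns the [10, 18] Boolean matrix into [10, 1, 18], so it will broadcast over the query dimension of the [10, 18, 18] attention matrix built below; every query position shares the same key-side mask. A quick sanity check (illustrative only, not part of the walkthrough):

dummy_attn = torch.zeros(10, 18, 18)
dummy_attn.masked_fill(src_mask == 0, -1e9)[0, :, 15]
# column 15 of batch 0 is padding, so all 18 query rows get -1e9 there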
batch_size = 10
seq_length = 18
embedding_size = 16
x_word_embedding = torch.randn([batch_size,seq_length,embedding_size])  # stands in for an embedding lookup: normally a learned [vocab_size, embedding_size] weight matrix indexed by token id
x_word_embedding.shape
torch.Size([10, 18, 16])
# Zero out the embedding of every padding position (token id 1)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        if x[i][j] == 1:
            x_word_embedding[i][j] = 0
x_word_embedding[0]
tensor([[ 0.5953, -0.7687, -0.1396,  0.0551, -0.1316,  2.1182,  0.3460, -0.2474,
         -0.7631,  0.7371,  0.2647, -0.9265, -0.0795, -0.9688, -1.9949, -1.0153],
        [-1.8476,  0.1382, -1.7616,  0.7693,  0.2297, -0.4240, -0.2053, -1.6727,
         -0.9341, -0.6124, -0.3068, -0.9722, -0.5115, -1.0359, -0.7867,  0.5991],
        [-0.6086,  0.8565, -0.5559,  2.0990, -0.5317, -0.0794,  0.4806,  0.3160,
          1.0327, -0.2135, -0.1543,  1.6888,  1.1976, -1.2593, -0.5984, -0.3219],
        [-1.2501,  0.6738,  1.1699,  0.3973,  0.0653, -1.2225, -0.2438, -0.6155,
         -0.6204, -1.8898, -0.5140, -2.0530,  0.7663,  0.4650, -1.0250, -1.1067],
        [-0.5502,  0.9821,  1.2645, -1.5540,  0.8937,  0.3531, -1.0345,  1.5634,
         -0.4247,  0.8320, -1.1106,  1.2879,  0.2906,  2.3099, -1.6235, -0.9464],
        [-0.2324, -0.3651,  0.4145,  0.9106,  0.1187, -0.2084,  0.2424,  1.0556,
         -2.1255,  1.6913, -1.0605,  1.6755,  1.0412,  0.2087,  1.6073,  0.1279],
        [ 0.2160,  0.7791, -0.4280, -0.0124, -2.2770, -0.9033, -0.1782, -1.2013,
          0.9789,  0.2439,  0.4394, -0.5341, -0.5609,  1.8852, -0.3370, -0.1769],
        [ 2.1392,  1.3648,  0.1819,  0.8782, -1.5937, -0.3102, -0.0304, -0.2514,
         -0.7271,  1.2352, -0.7463,  0.9655,  1.7330, -0.0797,  1.7071,  0.0998],
        [-0.3247,  0.2992, -1.9922,  0.1976, -0.1732,  1.1075, -3.4594, -0.1730,
          1.4339,  0.9275,  0.0803,  0.4825, -1.1535, -1.1750,  1.0148,  0.3328],
        [ 0.5756, -0.5909, -1.0873, -0.5838,  0.1695,  0.7564, -0.1989, -1.0383,
         -0.1502,  1.5897, -0.5680, -3.2284, -0.8693, -2.0997, -0.1233,  1.1241],
        [ 0.5985, -0.6717, -0.4839, -1.4869,  0.2176, -0.4596,  2.2802,  0.3979,
         -0.5376, -0.5777,  1.2366,  0.1043, -0.6600,  0.1091, -1.1323,  0.2859],
        [ 1.3520, -0.4221,  0.8972, -1.2319,  0.2628, -0.3468,  0.4029,  0.3405,
         -1.1161,  0.5008,  1.0744, -1.6060,  2.6297,  1.6144,  0.2159,  0.6276],
        [ 0.3464, -1.5550, -1.5556, -0.3224,  0.6951,  1.0998,  1.8592,  0.7744,
          0.9877,  0.7624,  2.5627,  0.5629, -1.3892,  0.4832, -0.1809, -0.3011],
        [-0.6122, -0.1236,  0.1143, -0.4748, -1.1046, -1.6141, -1.2581, -0.6381,
         -1.3857, -1.8965,  0.4942,  1.1739,  1.0684, -0.9261, -0.6500, -0.2302],
        [ 0.3307, -0.2460,  1.9470,  0.1465, -0.9497,  0.3327, -0.8901,  1.1441,
         -0.4879,  1.0638, -1.4439, -2.0664, -0.8945,  0.5613,  0.7411, -0.2546],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])
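In a real model, the random matrix plus the explicit double loop above would be a single embedding lookup; with padding_idx set, the padding rows stay at zero automatically. A sketch (the vocabulary size 10000 is an assumption; it merely needs to exceed the largest token id):

emb = nn.Embedding(10000, embedding_size, padding_idx=1)
x_word_embedding_alt = emb(x)  # [10, 18, 16]; rows where x == 1 are exactly zero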
x_position_embedding = torch.randn([18,16]).unsqueeze(0)
x_position_embedding
tensor([[[ 9.0368e-01,  5.8748e-01, -1.1009e+00, -1.0503e-01, -2.1601e+00,
           1.5621e+00,  4.6033e-01,  1.2967e+00, -2.1629e+00, -7.8287e-02,
          -9.9411e-01,  7.6217e-01,  4.9960e-02,  4.0106e-01, -2.0348e+00,
           4.9369e-01],
         [ 1.0890e+00,  2.2973e+00,  9.6957e-01,  1.3043e+00,  6.0618e-02,
           1.9305e+00,  3.8956e-02, -5.6996e-01, -1.2326e+00, -9.9662e-01,
          -6.2974e-01, -1.7874e+00,  2.6497e-01, -3.9136e-01, -1.5453e+00,
          -5.2771e-01],
         [-8.4272e-01, -1.0274e+00,  3.9564e-01,  2.3713e-02, -8.5774e-02,
          -1.9809e+00,  6.5302e-01, -9.5564e-02, -1.3990e-01, -2.4967e-01,
          -1.3083e-01,  1.2046e+00,  9.8682e-01, -8.6967e-01,  1.8375e-01,
           4.2060e-02],
         [ 1.3766e+00, -1.2898e+00, -5.6401e-01,  8.1396e-01,  8.3389e-01,
          -1.6490e+00, -1.0349e+00,  3.0295e-01,  3.5253e-01,  1.4929e+00,
           5.5201e-01,  1.2799e+00, -4.9722e-01, -1.5553e-01,  3.7034e-01,
           1.8389e-01],
         [ 5.6875e-01, -1.9672e+00, -1.7386e+00,  7.6477e-01, -1.7213e+00,
          -9.3602e-01,  9.6785e-01,  4.8722e-01,  6.3193e-02,  3.8820e-01,
          -5.6842e-02,  5.4492e-01,  8.9291e-02, -7.8848e-01, -7.5108e-01,
          -2.1830e-01],
         [-1.3706e+00, -3.7459e-01, -1.4454e+00, -6.9120e-02, -8.5179e-03,
           1.0005e+00, -7.9239e-01,  3.5959e-01, -5.8618e-01,  1.3782e+00,
          -1.5627e-01, -2.4377e-01,  5.0945e-01, -6.9910e-01,  3.4893e-01,
          -8.7715e-01],
         [-9.7867e-01, -9.5280e-01,  2.3949e-01, -5.0913e-01,  2.8120e-01,
           6.5280e-01,  1.5024e+00, -8.8424e-01, -4.7504e-01,  4.8769e-01,
           1.1893e+00, -6.1249e-01, -6.9277e-01, -2.6621e-01, -1.0576e+00,
           5.5870e-01],
         [ 1.5670e+00,  1.4179e+00, -4.3079e-01,  9.2860e-03,  9.1036e-02,
          -8.8635e-01, -7.3382e-01,  9.9911e-02,  1.0422e+00,  3.3808e-02,
          -4.3342e-01, -4.2554e-01, -1.2664e+00,  6.7073e-01, -4.3843e-01,
           2.4629e-01],
         [-4.0307e-01, -8.0981e-01, -7.3736e-01,  7.0317e-01, -6.7180e-01,
           3.4013e-02,  1.8024e-01,  6.2083e-01, -3.4933e-01, -6.2524e-01,
          -3.5893e-01,  1.1627e+00, -2.3080e-01, -9.5184e-01,  7.1967e-01,
          -6.2196e-01],
         [ 5.5927e-01, -8.6444e-01,  1.4131e-01,  1.0069e+00, -5.3246e-01,
          -1.3804e+00, -1.2412e+00, -6.2083e-01,  1.2844e+00, -8.0976e-01,
          -1.3851e+00, -2.3795e-01,  7.9758e-01, -1.4378e+00, -1.5737e+00,
          -8.9958e-01],
         [-1.7257e-01, -8.2184e-01, -1.2956e+00,  1.5484e+00, -5.5334e-02,
           9.2041e-01, -8.8723e-01,  6.2453e-01, -1.0932e-02,  1.8479e+00,
           1.9423e-01,  9.7021e-01,  2.9241e-01,  7.9268e-01, -5.6227e-01,
          -1.6743e+00],
         [ 1.2613e-01,  1.2770e-01,  7.3058e-01, -6.5678e-01, -9.4056e-01,
          -1.2946e+00, -7.7965e-01,  1.2064e-01, -1.8496e+00,  1.9518e+00,
          -1.6807e+00,  9.9174e-01,  5.3432e-01, -2.4413e-01, -9.0537e-01,
           8.8707e-01],
         [ 1.6266e+00, -1.2979e+00,  2.1201e-01, -1.2506e+00, -2.5751e-01,
          -1.4188e+00,  6.8915e-01,  7.5310e-01,  1.0268e+00, -4.7283e-02,
           6.0086e-01, -2.2722e+00,  1.4796e+00,  1.6694e-01,  9.7610e-01,
          -3.5627e-01],
         [ 6.9922e-01, -2.8308e-01,  1.8916e+00,  4.8927e-01, -1.2071e+00,
           5.6262e-01,  7.0220e-01, -9.3077e-01, -4.1099e-01, -8.8629e-01,
          -1.4075e+00,  1.4386e-01, -1.2417e+00,  3.6809e-02,  3.5345e-01,
          -1.0575e+00],
         [-6.3683e-01,  9.1762e-01,  4.9119e-01, -2.2506e+00, -8.0021e-01,
           2.0614e-01, -1.2295e+00, -1.1974e+00,  2.1372e-01, -9.5220e-01,
          -1.9652e+00, -4.0965e-02,  1.4382e+00, -1.1599e+00,  6.7839e-01,
           9.6810e-01],
         [ 3.6581e-01, -3.6475e-01,  7.9087e-01, -9.9523e-01,  1.0741e+00,
           3.9176e-01,  3.8599e-01,  1.0822e+00, -2.7792e-01, -1.0478e-01,
           1.7455e+00, -1.2391e+00,  4.7921e-01,  5.7139e-01,  7.3661e-01,
          -3.4478e-01],
         [ 1.3962e-01,  1.7028e-01, -3.8022e-01,  2.7366e+00, -4.0408e-01,
           2.1577e+00, -1.1733e-01,  2.2820e+00, -1.0575e+00, -7.8049e-02,
          -1.8555e+00, -1.5639e+00, -2.2194e-03,  9.1327e-01, -6.8152e-01,
           4.3902e-01],
         [ 1.3193e+00,  2.4797e-01,  7.9764e-01,  9.9411e-02, -5.5226e-01,
          -6.8628e-02, -3.8057e-01,  1.5772e+00,  8.9885e-01, -6.5663e-01,
           2.7259e+00,  5.9226e-01, -2.0406e+00,  5.5416e-01, -2.2941e-01,
          -4.6528e-01]]])
x_position_embedding.shape  # the same position in every sentence uses the same position vector
torch.Size([1, 18, 16])
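The position vectors here are random, purely for demonstration; the original Transformer paper uses fixed sinusoids instead. A sketch producing the same [1, 18, 16] shape:

import math
pos = torch.arange(seq_length).unsqueeze(1)   # [18, 1]
div = torch.exp(torch.arange(0, embedding_size, 2) * (-math.log(10000.0) / embedding_size))
pe = torch.zeros(seq_length, embedding_size)
pe[:, 0::2] = torch.sin(pos * div)            # even dimensions: sine
pe[:, 1::2] = torch.cos(pos * div)            # odd dimensions: cosine
pe = pe.unsqueeze(0)                          # [1, 18, 16], shared by every sentence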
x_embedding = x_word_embedding + x_position_embedding  # broadcasts over the batch dimension
x_embedding
tensor([[[ 1.4990, -0.1813, -1.2405,  ..., -0.5678, -4.0296, -0.5216],
         [-0.7586,  2.4355, -0.7921,  ..., -1.4273, -2.3320,  0.0714],
         [-1.4513, -0.1709, -0.1603,  ..., -2.1290, -0.4146, -0.2798],
         ...,
         [ 0.3658, -0.3647,  0.7909,  ...,  0.5714,  0.7366, -0.3448],
         [ 0.1396,  0.1703, -0.3802,  ...,  0.9133, -0.6815,  0.4390],
         [ 1.3193,  0.2480,  0.7976,  ...,  0.5542, -0.2294, -0.4653]],

        [[ 1.2682, -0.8150, -0.8226,  ...,  0.0490, -1.8821,  0.6543],
         [ 1.2540,  2.3416,  0.3683,  ...,  0.4880,  0.1284, -1.2846],
         [-0.9080,  0.3243, -2.1975,  ..., -2.0137,  0.2239, -0.0376],
         ...,
         [ 0.3658, -0.3647,  0.7909,  ...,  0.5714,  0.7366, -0.3448],
         [ 0.1396,  0.1703, -0.3802,  ...,  0.9133, -0.6815,  0.4390],
         [ 1.3193,  0.2480,  0.7976,  ...,  0.5542, -0.2294, -0.4653]],

        [[ 1.8732,  0.5552, -1.7469,  ...,  1.8612, -2.0750,  0.5315],
         [ 0.3223,  2.0160,  2.4301,  ..., -1.7063, -1.3378,  0.2593],
         [-0.3029, -1.4872, -1.3034,  ...,  0.5289, -0.0337,  0.4721],
         ...,
         [-0.1261, -1.1610,  0.8987,  ...,  1.1619,  2.2077, -0.1100],
         [ 0.1396,  0.1703, -0.3802,  ...,  0.9133, -0.6815,  0.4390],
         [ 1.3193,  0.2480,  0.7976,  ...,  0.5542, -0.2294, -0.4653]],

        ...,

        [[ 0.6126,  0.9536, -1.1994,  ...,  1.2692, -0.6369,  1.7213],
         [ 2.0293,  2.3032,  0.5438,  ..., -1.0498, -0.9796, -0.5360],
         [ 0.2768, -1.1700, -3.0397,  ..., -0.9378,  0.1812, -0.1183],
         ...,
         [ 0.3658, -0.3647,  0.7909,  ...,  0.5714,  0.7366, -0.3448],
         [ 0.1396,  0.1703, -0.3802,  ...,  0.9133, -0.6815,  0.4390],
         [ 1.3193,  0.2480,  0.7976,  ...,  0.5542, -0.2294, -0.4653]],

        [[ 0.2253, -0.3884, -1.9119,  ..., -1.1351, -1.7859,  0.1309],
         [ 0.7697,  1.8575,  0.9080,  ..., -0.9043,  0.5719, -2.3765],
         [-0.3397, -1.2380,  0.0837,  ..., -1.6820,  0.8802, -0.0421],
         ...,
         [ 0.5816, -1.1825,  0.6010,  ...,  0.6708, -0.8309, -0.3200],
         [ 1.7076, -0.1556, -0.1022,  ...,  1.1021,  0.6713,  0.8018],
         [ 2.7562, -0.3169,  0.5713,  ...,  0.3990, -0.1851, -0.3272]],

        [[ 2.5735,  1.8996,  0.4187,  ..., -0.0281, -1.4581,  0.0753],
         [-0.0460,  1.7525,  1.3994,  ..., -0.5160, -2.1990, -2.7955],
         [ 0.3476, -0.7091, -1.4106,  ..., -1.6300, -0.2675,  0.6924],
         ...,
         [ 0.3658, -0.3647,  0.7909,  ...,  0.5714,  0.7366, -0.3448],
         [ 0.1396,  0.1703, -0.3802,  ...,  0.9133, -0.6815,  0.4390],
         [ 1.3193,  0.2480,  0.7976,  ...,  0.5542, -0.2294, -0.4653]]])
x_embedding.shape
torch.Size([10, 18, 16])
batch_size = 10
seq_length = 18
embedding_size = 16
q_d = k_d = v_d = 20
# Bias-free linear projections play the role of W_Q, W_K, W_V
q = nn.Linear(embedding_size, q_d, bias=False)(x_embedding)
k = nn.Linear(embedding_size, k_d, bias=False)(x_embedding)
v = nn.Linear(embedding_size, v_d, bias=False)(x_embedding)
q.shape,k.shape,v.shape
(torch.Size([10, 18, 20]), torch.Size([10, 18, 20]), torch.Size([10, 18, 20]))
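Note that q_d and k_d must be equal because q and k are dotted together; v_d is independent and sets the width of the attention output (all three are 20 here for simplicity).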
attn = torch.matmul(q / 4, k.transpose(1, 2))  # scaled dot-product; the canonical scale is sqrt(k_d) = sqrt(20) ≈ 4.47, approximated by 4 here
attn.shape
torch.Size([10, 18, 18])
attn[0].shape
torch.Size([18, 18])
attn = attn.masked_fill(src_mask == 0, -1e9)  # fill padded key positions with a large negative number so softmax drives their weights to zero
attn[0]
tensor([[-6.0690e-01, -3.7683e-01,  1.5225e+00,  9.6109e-01,  9.9332e-01,
          1.8536e-01,  7.5841e-01,  8.1171e-01, -4.7227e-01, -1.6088e+00,
          6.3255e-01, -2.4172e-01, -3.0882e-01,  1.3427e+00, -1.1389e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-3.0765e-01, -1.5715e-01,  4.4040e-01,  4.1949e-01, -1.4558e-01,
         -2.8926e-01,  5.3010e-01,  5.7046e-02, -6.6750e-01, -1.2537e+00,
          2.4509e-01, -4.4911e-01,  4.4740e-01,  1.5970e+00, -6.1908e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.1161e-01,  1.1511e+00, -1.7925e+00,  2.2255e-01, -1.0840e+00,
         -7.0142e-02, -5.0922e-01,  1.7571e-01, -5.4323e-01, -4.8321e-01,
         -1.2211e+00,  3.6905e-01, -1.1800e+00,  8.3300e-01,  1.5945e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 5.4233e-02, -5.0000e-02, -6.8838e-01, -1.3039e-01, -5.2410e-01,
         -5.5713e-01,  1.9288e-02,  5.0029e-02, -7.9929e-02,  4.1053e-01,
         -3.4310e-01, -1.2978e-01,  3.6358e-01,  3.4465e-01,  5.2046e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-8.8978e-02,  8.0073e-01,  7.4081e-02,  4.5606e-01,  1.7942e-01,
         -5.0651e-02,  2.7160e-01,  6.4326e-01, -3.5353e-01, -1.7773e-01,
         -2.0484e-01,  6.1417e-03, -1.0183e+00,  5.1489e-01,  1.6607e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-3.6375e-01,  9.5096e-01, -1.0274e+00, -1.5157e-01, -1.2807e+00,
         -2.5263e-01,  3.9632e-02,  6.1543e-01,  4.6237e-01,  3.3684e-01,
         -6.3857e-01, -4.8063e-01, -2.0782e-01, -5.5224e-01,  1.7934e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 2.8373e-01, -1.0136e-01,  1.1654e+00,  9.3107e-01,  1.0568e+00,
         -7.8805e-02,  5.1141e-01, -9.4328e-01, -6.2343e-01,  4.3079e-01,
          1.0431e+00, -1.0827e-01,  7.1189e-01,  3.3160e-01, -1.1183e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 4.8535e-01, -1.1663e+00,  6.6636e-01, -9.6097e-01,  8.5208e-01,
          9.9473e-01, -5.2427e-02,  4.5534e-01,  4.1307e-01, -8.3508e-01,
          1.5996e+00,  8.0137e-01,  1.3822e+00, -2.4668e+00, -2.2721e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-4.1034e-02, -6.5552e-01,  2.6044e-01,  2.9065e-01, -3.2397e-01,
          2.3684e-01,  8.1972e-03,  1.0512e+00, -7.9683e-01,  4.2159e-01,
         -1.9782e-01,  1.7893e+00,  7.8592e-01,  4.2542e-02,  2.4891e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 6.4500e-02, -1.8950e+00, -2.1693e-01,  1.0840e-01,  2.2477e-01,
          6.7428e-01, -3.7683e-01,  3.2650e-01,  5.2687e-01,  5.6838e-01,
          1.4969e-01,  1.8243e+00,  1.1127e+00, -2.8123e-01,  4.6841e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 8.2095e-02,  4.1222e-01,  3.1565e-01,  8.2741e-01,  4.0937e-01,
         -6.6774e-01,  9.6502e-01,  4.8628e-02, -1.1333e+00, -2.4812e-01,
          2.4803e-01, -2.0652e-01,  9.9369e-03,  1.1184e+00, -2.5702e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-8.6849e-01, -1.8334e-01, -4.9383e-01, -2.5585e-01, -3.5734e-01,
          6.4596e-01, -6.3123e-01, -1.7421e-01,  2.5318e+00, -2.5341e-01,
          5.8007e-01, -1.5293e+00,  2.2613e-01, -2.1622e+00, -1.7305e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 4.9297e-01, -8.2548e-01,  3.7337e-01,  3.2121e-01,  1.0958e+00,
         -3.7185e-01,  4.2499e-01, -9.7494e-01,  1.0195e-01,  8.9203e-01,
          6.9874e-01, -1.8231e-01,  7.8013e-01, -3.2680e-01, -2.0724e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.1754e-02, -2.1956e-01, -1.8653e-01,  5.9911e-02,  4.4472e-01,
          1.3811e-01, -6.3979e-01, -5.6150e-02,  5.5480e-01,  2.0387e-01,
         -7.0942e-01, -1.7152e-01, -1.0433e+00,  1.2044e+00,  1.6884e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0712e-01, -6.6093e-01,  1.4028e-02, -9.5958e-01,  7.4170e-02,
          1.2432e+00, -1.2309e+00,  3.3976e-01,  2.2499e+00,  9.1287e-01,
         -3.0829e-01,  3.4668e-01, -6.4126e-01, -1.6625e+00,  6.1514e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-1.0567e-01,  6.0719e-02, -2.0683e-01, -2.4056e-01, -1.4933e-01,
         -3.5809e-01,  1.3756e-01, -2.6875e-01,  4.9080e-01,  1.6693e-01,
         -1.7357e-01, -8.6498e-01, -1.1757e-01, -3.3879e-02,  1.9470e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-2.3758e-01,  4.7806e-02,  1.3244e+00, -1.9672e-01, -5.8326e-02,
         -4.1213e-01,  3.4664e-01,  6.9897e-01, -6.6317e-01, -6.6268e-01,
          3.4629e-01, -4.4367e-01,  4.2473e-01,  5.9423e-02, -1.2536e+00,
         -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [ 1.0423e-01, -4.4003e-01,  9.3847e-01, -1.3148e-01,  4.3072e-01,
         -2.8801e-01,  4.1276e-01,  3.0477e-01, -9.8558e-01, -1.9937e-01,
          1.7873e-01,  3.9120e-01,  3.3835e-01,  3.1912e-01, -1.6030e-01,
         -1.0000e+09, -1.0000e+09, -1.0000e+09]], grad_fn=<SelectBackward0>)
scores = F.softmax(attn, dim=-1)  # masked positions now receive weight 0
scores[0]
tensor([[0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
         0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
        [0.0391, 0.0455, 0.0827, 0.0810, 0.0460, 0.0399, 0.0904, 0.0563, 0.0273,
         0.0152, 0.0680, 0.0340, 0.0832, 0.2628, 0.0287, 0.0000, 0.0000, 0.0000],
        [0.0471, 0.1663, 0.0088, 0.0657, 0.0178, 0.0490, 0.0316, 0.0627, 0.0306,
         0.0325, 0.0155, 0.0761, 0.0162, 0.1210, 0.2592, 0.0000, 0.0000, 0.0000],
        [0.0696, 0.0627, 0.0331, 0.0578, 0.0390, 0.0378, 0.0672, 0.0693, 0.0608,
         0.0994, 0.0468, 0.0579, 0.0948, 0.0930, 0.1109, 0.0000, 0.0000, 0.0000],
        [0.0516, 0.1256, 0.0607, 0.0890, 0.0675, 0.0536, 0.0740, 0.1073, 0.0396,
         0.0472, 0.0459, 0.0567, 0.0204, 0.0944, 0.0666, 0.0000, 0.0000, 0.0000],
        [0.0458, 0.1707, 0.0236, 0.0567, 0.0183, 0.0512, 0.0686, 0.1220, 0.1047,
         0.0923, 0.0348, 0.0408, 0.0536, 0.0380, 0.0789, 0.0000, 0.0000, 0.0000],
        [0.0567, 0.0386, 0.1370, 0.1084, 0.1229, 0.0395, 0.0713, 0.0166, 0.0229,
         0.0657, 0.1213, 0.0383, 0.0871, 0.0595, 0.0140, 0.0000, 0.0000, 0.0000],
        [0.0646, 0.0124, 0.0775, 0.0152, 0.0933, 0.1076, 0.0378, 0.0627, 0.0601,
         0.0173, 0.1970, 0.0887, 0.1585, 0.0034, 0.0041, 0.0000, 0.0000, 0.0000],
        [0.0412, 0.0223, 0.0557, 0.0575, 0.0311, 0.0544, 0.0433, 0.1229, 0.0194,
         0.0655, 0.0353, 0.2572, 0.0943, 0.0448, 0.0551, 0.0000, 0.0000, 0.0000],
        [0.0433, 0.0061, 0.0327, 0.0452, 0.0508, 0.0797, 0.0278, 0.0563, 0.0687,
         0.0717, 0.0471, 0.2516, 0.1235, 0.0306, 0.0648, 0.0000, 0.0000, 0.0000],
        [0.0543, 0.0756, 0.0686, 0.1144, 0.0753, 0.0257, 0.1313, 0.0525, 0.0161,
         0.0390, 0.0641, 0.0407, 0.0505, 0.1531, 0.0387, 0.0000, 0.0000, 0.0000],
        [0.0178, 0.0354, 0.0260, 0.0329, 0.0297, 0.0811, 0.0226, 0.0357, 0.5348,
         0.0330, 0.0759, 0.0092, 0.0533, 0.0049, 0.0075, 0.0000, 0.0000, 0.0000],
        [0.0795, 0.0213, 0.0705, 0.0669, 0.1452, 0.0335, 0.0742, 0.0183, 0.0537,
         0.1184, 0.0976, 0.0404, 0.1059, 0.0350, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.0465, 0.0378, 0.0390, 0.0499, 0.0734, 0.0540, 0.0248, 0.0445, 0.0819,
         0.0577, 0.0231, 0.0396, 0.0166, 0.1568, 0.2545, 0.0000, 0.0000, 0.0000],
        [0.0349, 0.0201, 0.0394, 0.0149, 0.0418, 0.1347, 0.0113, 0.0546, 0.3685,
         0.0968, 0.0285, 0.0549, 0.0205, 0.0074, 0.0719, 0.0000, 0.0000, 0.0000],
        [0.0634, 0.0749, 0.0573, 0.0554, 0.0607, 0.0493, 0.0809, 0.0539, 0.1152,
         0.0833, 0.0593, 0.0297, 0.0627, 0.0682, 0.0857, 0.0000, 0.0000, 0.0000],
        [0.0453, 0.0602, 0.2159, 0.0472, 0.0542, 0.0380, 0.0812, 0.1155, 0.0296,
         0.0296, 0.0812, 0.0369, 0.0878, 0.0609, 0.0164, 0.0000, 0.0000, 0.0000],
        [0.0622, 0.0361, 0.1433, 0.0491, 0.0862, 0.0420, 0.0847, 0.0760, 0.0209,
         0.0459, 0.0670, 0.0829, 0.0786, 0.0771, 0.0478, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SelectBackward0>)
sum(scores[0][0])  # each row of attention weights still sums to 1
tensor(1.0000, grad_fn=<AddBackward0>)
scores.shape
torch.Size([10, 18, 18])
v.shape
torch.Size([10, 18, 20])
v[0].shape
torch.Size([18, 20])
scores[0]
tensor([[0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
         0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
        [0.0391, 0.0455, 0.0827, 0.0810, 0.0460, 0.0399, 0.0904, 0.0563, 0.0273,
         0.0152, 0.0680, 0.0340, 0.0832, 0.2628, 0.0287, 0.0000, 0.0000, 0.0000],
        [0.0471, 0.1663, 0.0088, 0.0657, 0.0178, 0.0490, 0.0316, 0.0627, 0.0306,
         0.0325, 0.0155, 0.0761, 0.0162, 0.1210, 0.2592, 0.0000, 0.0000, 0.0000],
        [0.0696, 0.0627, 0.0331, 0.0578, 0.0390, 0.0378, 0.0672, 0.0693, 0.0608,
         0.0994, 0.0468, 0.0579, 0.0948, 0.0930, 0.1109, 0.0000, 0.0000, 0.0000],
        [0.0516, 0.1256, 0.0607, 0.0890, 0.0675, 0.0536, 0.0740, 0.1073, 0.0396,
         0.0472, 0.0459, 0.0567, 0.0204, 0.0944, 0.0666, 0.0000, 0.0000, 0.0000],
        [0.0458, 0.1707, 0.0236, 0.0567, 0.0183, 0.0512, 0.0686, 0.1220, 0.1047,
         0.0923, 0.0348, 0.0408, 0.0536, 0.0380, 0.0789, 0.0000, 0.0000, 0.0000],
        [0.0567, 0.0386, 0.1370, 0.1084, 0.1229, 0.0395, 0.0713, 0.0166, 0.0229,
         0.0657, 0.1213, 0.0383, 0.0871, 0.0595, 0.0140, 0.0000, 0.0000, 0.0000],
        [0.0646, 0.0124, 0.0775, 0.0152, 0.0933, 0.1076, 0.0378, 0.0627, 0.0601,
         0.0173, 0.1970, 0.0887, 0.1585, 0.0034, 0.0041, 0.0000, 0.0000, 0.0000],
        [0.0412, 0.0223, 0.0557, 0.0575, 0.0311, 0.0544, 0.0433, 0.1229, 0.0194,
         0.0655, 0.0353, 0.2572, 0.0943, 0.0448, 0.0551, 0.0000, 0.0000, 0.0000],
        [0.0433, 0.0061, 0.0327, 0.0452, 0.0508, 0.0797, 0.0278, 0.0563, 0.0687,
         0.0717, 0.0471, 0.2516, 0.1235, 0.0306, 0.0648, 0.0000, 0.0000, 0.0000],
        [0.0543, 0.0756, 0.0686, 0.1144, 0.0753, 0.0257, 0.1313, 0.0525, 0.0161,
         0.0390, 0.0641, 0.0407, 0.0505, 0.1531, 0.0387, 0.0000, 0.0000, 0.0000],
        [0.0178, 0.0354, 0.0260, 0.0329, 0.0297, 0.0811, 0.0226, 0.0357, 0.5348,
         0.0330, 0.0759, 0.0092, 0.0533, 0.0049, 0.0075, 0.0000, 0.0000, 0.0000],
        [0.0795, 0.0213, 0.0705, 0.0669, 0.1452, 0.0335, 0.0742, 0.0183, 0.0537,
         0.1184, 0.0976, 0.0404, 0.1059, 0.0350, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.0465, 0.0378, 0.0390, 0.0499, 0.0734, 0.0540, 0.0248, 0.0445, 0.0819,
         0.0577, 0.0231, 0.0396, 0.0166, 0.1568, 0.2545, 0.0000, 0.0000, 0.0000],
        [0.0349, 0.0201, 0.0394, 0.0149, 0.0418, 0.1347, 0.0113, 0.0546, 0.3685,
         0.0968, 0.0285, 0.0549, 0.0205, 0.0074, 0.0719, 0.0000, 0.0000, 0.0000],
        [0.0634, 0.0749, 0.0573, 0.0554, 0.0607, 0.0493, 0.0809, 0.0539, 0.1152,
         0.0833, 0.0593, 0.0297, 0.0627, 0.0682, 0.0857, 0.0000, 0.0000, 0.0000],
        [0.0453, 0.0602, 0.2159, 0.0472, 0.0542, 0.0380, 0.0812, 0.1155, 0.0296,
         0.0296, 0.0812, 0.0369, 0.0878, 0.0609, 0.0164, 0.0000, 0.0000, 0.0000],
        [0.0622, 0.0361, 0.1433, 0.0491, 0.0862, 0.0420, 0.0847, 0.0760, 0.0209,
         0.0459, 0.0670, 0.0829, 0.0786, 0.0771, 0.0478, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SelectBackward0>)
output = torch.matmul(scores, v)
output.shape
torch.Size([10, 18, 20])
output[0][0]
tensor([ 0.1268,  0.2170, -0.2703,  0.2993, -0.2006, -0.3046,  0.2196, -0.0223,
         0.7093, -0.1499, -0.0298, -0.1743,  0.1631, -0.1972, -0.3261,  0.1589,
        -0.1068,  0.0306, -0.2259,  0.0282], grad_fn=<SelectBackward0>)
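The weights of padded key positions are zero, so they contribute nothing to this output. Note, however, that the padded query rows (rows 15-17 of output[0]) are still nonzero, because each padded query still attends over the valid keys; in a full model those rows are simply ignored downstream, e.g. by masking the loss.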
scores.shape
torch.Size([10, 18, 18])
scores[0][0]
tensor([0.0217, 0.0273, 0.1827, 0.1042, 0.1076, 0.0480, 0.0851, 0.0897, 0.0248,
        0.0080, 0.0750, 0.0313, 0.0293, 0.1526, 0.0128, 0.0000, 0.0000, 0.0000],
       grad_fn=<SelectBackward0>)
v.shape
torch.Size([10, 18, 20])
# The matrix product is unrolled below to make the computation explicit; in practice the matrix form above is used.
output1 = []
for i in range(len(scores[0][0])):
    # weight the i-th value vector v[0][i] by the scalar attention weight scores[0][0][i]
    output1.append(scores[0][0][i].detach().numpy() * v[0][i].detach().numpy())
output1
[array([-0.00356334,  0.01569117, -0.02043939, -0.02862189,  0.00840736,
        -0.01401182,  0.00872742,  0.03408484,  0.00019593, -0.00853396,
         0.00951094, -0.01021247,  0.01006769,  0.04058198,  0.02450397,
        -0.02066059,  0.02544492,  0.01739634, -0.00912306, -0.00440777],
       dtype=float32),
 array([ 0.02174385,  0.02181036, -0.00668943, -0.00284444, -0.01792286,
        -0.05496887,  0.00642328,  0.06035611,  0.00246261,  0.03883798,
         0.00429785, -0.00607146,  0.05036234, -0.01484454,  0.06557821,
        -0.01962684,  0.02545645,  0.03820134,  0.00964658, -0.0618248 ],
       dtype=float32),
 array([ 0.00981068,  0.08241012,  0.0435107 ,  0.26321438, -0.21280633,
        -0.065763  ,  0.22667795, -0.11789013,  0.1957324 , -0.09090293,
         0.13033353, -0.17513736, -0.02819123, -0.27770212, -0.16380529,
        -0.16696015, -0.22296761, -0.15442224, -0.14996761, -0.04623051],
       dtype=float32),
 array([ 0.05863579, -0.01563031,  0.02628604,  0.04352596,  0.00123114,
        -0.12593815, -0.05850247, -0.04619041,  0.14190896, -0.00667505,
        -0.06504327, -0.02457336,  0.01267471, -0.08988734,  0.02068387,
         0.1141569 , -0.05952757,  0.09277385,  0.04706449, -0.00445588],
       dtype=float32),
 array([-0.01174144,  0.02150498, -0.06701917, -0.08503368,  0.05980498,
         0.02390346, -0.06823318, -0.02788872,  0.00625273, -0.08193176,
        -0.08811162, -0.04064612, -0.01717038,  0.01876009, -0.05648129,
        -0.03940401,  0.04552145,  0.05847974,  0.02593314,  0.08681556],
       dtype=float32),
 array([ 0.0046064 ,  0.0753553 , -0.05663994, -0.07728799, -0.05931624,
         0.05093202,  0.03476338,  0.00069166,  0.00120347,  0.01253372,
        -0.00336849,  0.00602924,  0.03808581,  0.06855204,  0.00205304,
        -0.04280849,  0.02934068,  0.04572994, -0.10011108,  0.05379374],
       dtype=float32),
 array([ 0.00855533,  0.01972056, -0.02584991,  0.012734  ,  0.04426347,
        -0.01907244, -0.06026317,  0.08408913, -0.08700912,  0.02952116,
         0.09744457, -0.05496382,  0.05024561, -0.04637939, -0.04548782,
         0.02788595,  0.01097868, -0.02408975,  0.06984393, -0.06373438],
       dtype=float32),
 array([ 5.0336070e-02, -9.1202877e-02,  1.1917996e-02,  1.2563699e-01,
        -4.5944070e-03,  5.7880916e-03,  1.5337752e-01,  4.5809049e-02,
         8.7222785e-02, -7.0060484e-02, -1.0104631e-01,  6.6017680e-02,
         8.4173165e-02,  2.7579835e-02, -6.1708100e-02,  5.5133272e-02,
         3.2097664e-02, -2.4366794e-02,  1.0262818e-04,  8.8561904e-03],
       dtype=float32),
 array([-0.0292644 ,  0.01630488, -0.01253164, -0.0070375 , -0.03961598,
         0.05226943,  0.02208326, -0.03990135,  0.01585516, -0.00127062,
        -0.0113592 ,  0.02213981,  0.02005158,  0.01969342, -0.0070779 ,
        -0.00585135, -0.00277652, -0.00163333, -0.03733382,  0.01124388],
       dtype=float32),
 array([ 6.2298100e-03,  1.8686900e-03,  2.7403934e-03, -3.5018041e-03,
        -4.4463761e-03, -4.6206997e-03,  1.9569243e-03,  2.6030848e-03,
         1.3119811e-02,  8.0437390e-03,  2.0034057e-03, -2.2230356e-03,
         7.5591272e-03, -2.4101253e-04,  8.6891865e-03, -6.3293730e-05,
        -9.3702022e-03,  2.1057955e-03,  1.9265222e-03, -1.9525980e-03],
       dtype=float32),
 array([-0.00207176,  0.03381819, -0.00062723, -0.05788399,  0.03948465,
         0.03536145,  0.00404582, -0.01121849, -0.05850476, -0.0454405 ,
         0.00821594, -0.05255793, -0.06612497,  0.04488177, -0.0387372 ,
        -0.02739964, -0.03441686,  0.03766139,  0.0335264 ,  0.05324583],
       dtype=float32),
 array([ 0.03819073, -0.02459533, -0.03861014, -0.01953255,  0.0140466 ,
        -0.03961703, -0.01613682,  0.0444642 ,  0.01884853, -0.0135273 ,
        -0.01088831,  0.00310303, -0.00310043,  0.03728347, -0.01751549,
         0.00054721,  0.04011041,  0.01901039, -0.03186378,  0.05084344],
       dtype=float32),
 array([-0.01383471, -0.03603299,  0.0306312 , -0.02414859,  0.02562755,
         0.02530815,  0.00139578, -0.02436204, -0.02043813, -0.01283113,
         0.0208585 , -0.01336401, -0.07284833,  0.03085884, -0.04238587,
         0.02426553, -0.05287703, -0.0112904 ,  0.03723892,  0.02734692],
       dtype=float32),
 array([-0.00945639,  0.09796111, -0.14330448,  0.1691841 , -0.0494328 ,
        -0.17400241, -0.02735225, -0.03271891,  0.3773934 ,  0.0788412 ,
        -0.01655269,  0.09439775,  0.05920107, -0.06631016, -0.02382696,
         0.2521283 ,  0.04862823, -0.06399021, -0.11209722, -0.07742304],
       dtype=float32),
 array([-0.00139221, -0.00199078, -0.01365093, -0.009141  , -0.00532917,
        -0.00019725, -0.00937385,  0.00572658,  0.01502065,  0.01348137,
        -0.00606315,  0.0137309 ,  0.01815296,  0.00999111,  0.0094426 ,
         0.00756134,  0.01759486, -0.00099205, -0.01064409, -0.00396364],
       dtype=float32),
 array([-0., -0.,  0., -0.,  0., -0., -0., -0., -0., -0.,  0., -0., -0.,
         0., -0.,  0., -0.,  0.,  0.,  0.], dtype=float32),
 array([-0.,  0.,  0., -0.,  0., -0.,  0.,  0.,  0., -0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0., -0.], dtype=float32),
 array([-0., -0.,  0.,  0.,  0.,  0., -0., -0., -0., -0.,  0., -0., -0.,
         0., -0.,  0., -0., -0.,  0.,  0.], dtype=float32)]
import numpy as np
np.array(output1).reshape([18,-1])  # stack the 18 weighted value vectors into an [18, 20] array
array([[-3.56333586e-03,  1.56911686e-02, -2.04393901e-02,
        -2.86218859e-02,  8.40735529e-03, -1.40118189e-02,
         8.72741546e-03,  3.40848416e-02,  1.95933360e-04,
        -8.53395835e-03,  9.51094087e-03, -1.02124671e-02,
         1.00676939e-02,  4.05819751e-02,  2.45039724e-02,
        -2.06605904e-02,  2.54449248e-02,  1.73963401e-02,
        -9.12306458e-03, -4.40777233e-03],
       [ 2.17438508e-02,  2.18103621e-02, -6.68943441e-03,
        -2.84443772e-03, -1.79228578e-02, -5.49688675e-02,
         6.42327731e-03,  6.03561103e-02,  2.46260897e-03,
         3.88379842e-02,  4.29785158e-03, -6.07146276e-03,
         5.03623448e-02, -1.48445424e-02,  6.55782074e-02,
        -1.96268391e-02,  2.54564472e-02,  3.82013433e-02,
         9.64657590e-03, -6.18248023e-02],
       [ 9.81067866e-03,  8.24101195e-02,  4.35107015e-02,
         2.63214380e-01, -2.12806329e-01, -6.57629967e-02,
         2.26677954e-01, -1.17890127e-01,  1.95732400e-01,
        -9.09029320e-02,  1.30333528e-01, -1.75137356e-01,
        -2.81912331e-02, -2.77702123e-01, -1.63805291e-01,
        -1.66960150e-01, -2.22967610e-01, -1.54422238e-01,
        -1.49967611e-01, -4.62305136e-02],
       [ 5.86357936e-02, -1.56303123e-02,  2.62860432e-02,
         4.35259603e-02,  1.23114337e-03, -1.25938147e-01,
        -5.85024655e-02, -4.61904071e-02,  1.41908959e-01,
        -6.67505478e-03, -6.50432706e-02, -2.45733596e-02,
         1.26747135e-02, -8.98873433e-02,  2.06838697e-02,
         1.14156902e-01, -5.95275685e-02,  9.27738547e-02,
         4.70644869e-02, -4.45587980e-03],
       [-1.17414435e-02,  2.15049833e-02, -6.70191720e-02,
        -8.50336775e-02,  5.98049797e-02,  2.39034556e-02,
        -6.82331845e-02, -2.78887209e-02,  6.25273399e-03,
        -8.19317624e-02, -8.81116167e-02, -4.06461246e-02,
        -1.71703752e-02,  1.87600926e-02, -5.64812906e-02,
        -3.94040123e-02,  4.55214530e-02,  5.84797412e-02,
         2.59331409e-02,  8.68155584e-02],
       [ 4.60639922e-03,  7.53552988e-02, -5.66399395e-02,
        -7.72879943e-02, -5.93162440e-02,  5.09320162e-02,
         3.47633846e-02,  6.91655499e-04,  1.20347342e-03,
         1.25337243e-02, -3.36848991e-03,  6.02923892e-03,
         3.80858146e-02,  6.85520396e-02,  2.05303659e-03,
        -4.28084917e-02,  2.93406751e-02,  4.57299352e-02,
        -1.00111082e-01,  5.37937433e-02],
       [ 8.55532568e-03,  1.97205599e-02, -2.58499123e-02,
         1.27340043e-02,  4.42634709e-02, -1.90724358e-02,
        -6.02631718e-02,  8.40891302e-02, -8.70091245e-02,
         2.95211598e-02,  9.74445716e-02, -5.49638234e-02,
         5.02456091e-02, -4.63793948e-02, -4.54878211e-02,
         2.78859455e-02,  1.09786754e-02, -2.40897462e-02,
         6.98439255e-02, -6.37343824e-02],
       [ 5.03360704e-02, -9.12028775e-02,  1.19179962e-02,
         1.25636995e-01, -4.59440704e-03,  5.78809157e-03,
         1.53377518e-01,  4.58090492e-02,  8.72227848e-02,
        -7.00604841e-02, -1.01046309e-01,  6.60176799e-02,
         8.41731653e-02,  2.75798347e-02, -6.17081001e-02,
         5.51332720e-02,  3.20976637e-02, -2.43667942e-02,
         1.02628183e-04,  8.85619037e-03],
       [-2.92643961e-02,  1.63048804e-02, -1.25316372e-02,
        -7.03750225e-03, -3.96159813e-02,  5.22694290e-02,
         2.20832583e-02, -3.99013534e-02,  1.58551577e-02,
        -1.27062469e-03, -1.13591971e-02,  2.21398100e-02,
         2.00515836e-02,  1.96934212e-02, -7.07789976e-03,
        -5.85135119e-03, -2.77651520e-03, -1.63333223e-03,
        -3.73338237e-02,  1.12438807e-02],
       [ 6.22980995e-03,  1.86869001e-03,  2.74039339e-03,
        -3.50180408e-03, -4.44637612e-03, -4.62069968e-03,
         1.95692433e-03,  2.60308478e-03,  1.31198112e-02,
         8.04373901e-03,  2.00340571e-03, -2.22303555e-03,
         7.55912717e-03, -2.41012531e-04,  8.68918654e-03,
        -6.32937299e-05, -9.37020220e-03,  2.10579555e-03,
         1.92652224e-03, -1.95259799e-03],
       [-2.07175617e-03,  3.38181928e-02, -6.27233065e-04,
        -5.78839891e-02,  3.94846499e-02,  3.53614539e-02,
         4.04581567e-03, -1.12184929e-02, -5.85047640e-02,
        -4.54404950e-02,  8.21593590e-03, -5.25579341e-02,
        -6.61249682e-02,  4.48817722e-02, -3.87372039e-02,
        -2.73996368e-02, -3.44168618e-02,  3.76613885e-02,
         3.35264020e-02,  5.32458313e-02],
       [ 3.81907299e-02, -2.45953333e-02, -3.86101417e-02,
        -1.95325539e-02,  1.40466047e-02, -3.96170318e-02,
        -1.61368232e-02,  4.44641970e-02,  1.88485272e-02,
        -1.35272965e-02, -1.08883120e-02,  3.10303201e-03,
        -3.10042920e-03,  3.72834690e-02, -1.75154917e-02,
         5.47210744e-04,  4.01104130e-02,  1.90103948e-02,
        -3.18637751e-02,  5.08434363e-02],
       [-1.38347102e-02, -3.60329933e-02,  3.06312013e-02,
        -2.41485890e-02,  2.56275497e-02,  2.53081508e-02,
         1.39578118e-03, -2.43620351e-02, -2.04381309e-02,
        -1.28311319e-02,  2.08585002e-02, -1.33640114e-02,
        -7.28483349e-02,  3.08588445e-02, -4.23858687e-02,
         2.42655315e-02, -5.28770313e-02, -1.12903966e-02,
         3.72389220e-02,  2.73469202e-02],
       [-9.45638865e-03,  9.79611129e-02, -1.43304482e-01,
         1.69184104e-01, -4.94327955e-02, -1.74002409e-01,
        -2.73522511e-02, -3.27189118e-02,  3.77393395e-01,
         7.88412020e-02, -1.65526886e-02,  9.43977535e-02,
         5.92010729e-02, -6.63101599e-02, -2.38269642e-02,
         2.52128303e-01,  4.86282334e-02, -6.39902055e-02,
        -1.12097219e-01, -7.74230361e-02],
       [-1.39220979e-03, -1.99078326e-03, -1.36509314e-02,
        -9.14100278e-03, -5.32916794e-03, -1.97253496e-04,
        -9.37385298e-03,  5.72658470e-03,  1.50206508e-02,
         1.34813692e-02, -6.06315304e-03,  1.37309004e-02,
         1.81529596e-02,  9.99111217e-03,  9.44260042e-03,
         7.56133953e-03,  1.75948571e-02, -9.92051209e-04,
        -1.06440932e-02, -3.96363717e-03],
       [-0.00000000e+00, -0.00000000e+00,  0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
         0.00000000e+00, -0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00],
       [-0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -0.00000000e+00],
       [-0.00000000e+00, -0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
         0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
         0.00000000e+00,  0.00000000e+00]], dtype=float32)
# This is single-head attention; multi-head attention additionally splits into heads and concatenates their outputs (see the sketch after the final check below).
np.sum(np.array(output1).reshape([18,-1]), axis=-2)  # sum the 18 weighted value vectors; this should match output[0][0]
array([ 0.12678438,  0.21699306, -0.27027592,  0.29926202, -0.20059839,
       -0.3046291 ,  0.21958958, -0.0223454 ,  0.70926446, -0.14991458,
       -0.02976831, -0.17433114,  0.16313875, -0.197182  , -0.32607505,
        0.15890414, -0.10676249,  0.03057403, -0.22585806,  0.02815294],
      dtype=float32)
output[0][0]
tensor([ 0.1268,  0.2170, -0.2703,  0.2993, -0.2006, -0.3046,  0.2196, -0.0223,
         0.7093, -0.1499, -0.0298, -0.1743,  0.1631, -0.1972, -0.3261,  0.1589,
        -0.1068,  0.0306, -0.2259,  0.0282], grad_fn=<SelectBackward0>)
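The manual sum matches output[0][0], confirming the computation. To extend this to multi-head attention, the same [batch, 1, len] padding mask just gains one more broadcast axis for the heads. A minimal sketch under assumed sizes (n_head=4, d_k=d_v=5 are illustrative, not values from the walkthrough):

class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v
        self.w_q = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_k = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_v = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
    def forward(self, q, k, v, mask=None):
        b, lq, _ = q.shape
        lk = k.shape[1]
        # split into heads: (b, n_head, len, d_k)
        q = self.w_q(q).view(b, lq, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(b, lk, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(b, lk, self.n_head, self.d_v).transpose(1, 2)
        attn = torch.matmul(q / self.d_k ** 0.5, k.transpose(-2, -1))
        if mask is not None:
            # (b, 1, len) -> (b, 1, 1, len): broadcasts over heads and query positions
            attn = attn.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)                   # (b, n_head, lq, d_v)
        out = out.transpose(1, 2).reshape(b, lq, -1)  # concat the heads
        return self.fc(out)

mha = MultiHeadAttention(n_head=4, d_model=16, d_k=5, d_v=5)
mha(x_embedding, x_embedding, x_embedding, mask=src_mask).shape  # torch.Size([10, 18, 16])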