[Source Code Walkthrough] The Scaled Dot-Product Attention in Transformer, Explained

import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    # shape: query = key = value ----> [batch_size, 8, max_length, 64]

    d_k = query.size(-1)

    # key with its last two dimensions swapped: [batch_size, 8, 64, max_length]
    # scores has shape: [batch_size, 8, max_length, max_length]
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)

    # padding mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # analysis point 1

    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
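
A minimal usage sketch (the sizes batch_size=2 and max_length=10 and the lengths list are hypothetical placeholders; the 8 heads and d_k=64 follow the shape comments above). The padding mask is built with shape [batch_size, 1, 1, max_length] so it broadcasts across heads and query positions inside masked_fill:

# Hypothetical sizes for illustration only.
batch_size, n_heads, max_length, d_k = 2, 8, 10, 64

query = torch.randn(batch_size, n_heads, max_length, d_k)
key = torch.randn(batch_size, n_heads, max_length, d_k)
value = torch.randn(batch_size, n_heads, max_length, d_k)

# Assumed padding mask: 1 = real token, 0 = padding.
# Shape [batch_size, 1, 1, max_length] broadcasts over heads and query positions.
lengths = [10, 6]  # assumed valid length of each sequence
mask = torch.zeros(batch_size, 1, 1, max_length, dtype=torch.long)
for i, src_len in enumerate(lengths):
    mask[i, :, :, :src_len] = 1

out, p_attn = attention(query, key, value, mask=mask)
print(out.shape)     # torch.Size([2, 8, 10, 64])
print(p_attn.shape)  # torch.Size([2, 8, 10, 10])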

Source code analysis

Analysis point 1: scores.masked_fill(mask == 0, -1e9)

In the scaled dot-product step, scores.masked_fill(mask == 0, -1e9) fills every position where the mask equals 0 with the very large negative value -1e9; after softmax, those positions get an attention weight of (effectively) 0, so padded tokens receive no attention.
For the masked_fill function, see: https://blog.csdn.net/qq_41568188/article/details/107281395
masked_fill(mask, value) fills the elements of a tensor with value wherever the boolean mask is True; here the boolean mask is mask == 0, so the positions where the original mask is 0 are the ones filled.
The mask must be broadcastable to the shape of the tensor being filled.

import torch
import torch.nn.functional as F

a = torch.randn(5, 6)
print(a)

# Valid length of each of the 5 rows.
x = [5, 4, 3, 2, 1]

# Build a mask: 0 for the first src_len positions of each row, 1 for the rest.
mask = torch.zeros(5, 6, dtype=torch.int)
for e_id, src_len in enumerate(x):
    mask[e_id, src_len:] = 1
print(mask)

# Fill the positions where mask == 0 with -inf, then apply softmax:
# those positions end up with a weight of exactly 0.
a.data.masked_fill_(mask == 0, -float('inf'))
print(a)
print(F.softmax(a, dim=-1))
# Output:
tensor([[-1.8453, -0.7031,  0.0066, -1.0771, -0.5282, -0.1669],
        [ 1.0285, -0.1086, -0.9871, -1.2061, -0.7845,  1.5072],
        [ 1.8313,  2.4513, -0.1615, -1.2768, -0.5887, -1.2990],
        [ 0.4653,  0.7976,  0.2020, -0.0886, -0.9101, -2.9927],
        [-0.5556, -0.5319, -0.1768, -0.4238,  1.2213, -1.9120]])
tensor([[0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1]], dtype=torch.int32)
tensor([[   -inf,    -inf,    -inf,    -inf,    -inf, -0.1669],
        [   -inf,    -inf,    -inf,    -inf, -0.7845,  1.5072],
        [   -inf,    -inf,    -inf, -1.2768, -0.5887, -1.2990],
        [   -inf,    -inf,  0.2020, -0.0886, -0.9101, -2.9927],
        [   -inf, -0.5319, -0.1768, -0.4238,  1.2213, -1.9120]])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0918, 0.9082],
        [0.0000, 0.0000, 0.0000, 0.2520, 0.5015, 0.2465],
        [0.0000, 0.0000, 0.4722, 0.3531, 0.1553, 0.0194],
        [0.0000, 0.1045, 0.1491, 0.1165, 0.6035, 0.0263]])
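
The example above fills with -inf, while the Transformer code uses -1e9. The two behave almost identically after softmax; one practical difference is that -1e9 is finite, so a row that happens to be fully masked yields uniform weights instead of NaN. A quick check with hypothetical scores:

# Quick check that -1e9 behaves like -inf after softmax
# (hypothetical scores, for illustration only).
scores = torch.tensor([2.0, 1.0, -1e9, -1e9])
print(F.softmax(scores, dim=-1))
# tensor([0.7311, 0.2689, 0.0000, 0.0000])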