Hand-Writing the Attention Mechanism for Interviews
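
A classic interview exercise is to implement scaled dot-product attention and multi-head attention from scratch in PyTorch. The core formula is Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension; multi-head attention projects Q, K and V into n_head subspaces of size d_k = d_model / n_head, runs attention in each head in parallel, concatenates the heads, and applies a final linear projection. The code below implements both, together with simple shape tests.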

import torch
import torch.nn as nn
import torch.nn.functional as F

def SelfAttention(q, k, v, mask=None, dropout=None):
    # Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        # Block positions where mask == 0 with a large negative value
        # so they receive (almost) zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, v)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"

        self.n_head = n_head
        self.d_k = d_model // n_head  # dimension of each head
        self.d_model = d_model

        # Linear projections for queries, keys, values, and the final output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(p=dropout)
    
    def forward(self, q, k, v, mask=None):
        n_batch = q.size(0)
        
        # Project, then split into heads:
        # (batch, seq_len, d_model) -> (batch, n_head, seq_len, d_k)
        q = self.w_q(q).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)

        if mask is not None:
            # (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
            # so the same mask broadcasts over all heads
            mask = mask.unsqueeze(1)

        # Scaled dot-product attention on all heads in parallel
        out = SelfAttention(q, k, v, mask, self.dropout)

        # Merge heads: (batch, n_head, seq_len, d_k) -> (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)

        return self.w_o(out)
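
# A minimal sketch of how a causal (look-ahead) mask could be built for
# decoder-style self-attention; the helper name causal_mask is illustrative,
# not something the classes above require. Entries equal to 1 are kept and
# entries equal to 0 are replaced with -1e9 inside SelfAttention, so each
# position can only attend to itself and to earlier positions.
def causal_mask(seq_length):
    # Lower-triangular matrix of shape (1, seq_length, seq_length),
    # broadcastable over the batch (and, after unsqueeze(1), the head) dimension.
    return torch.tril(torch.ones(1, seq_length, seq_length)).bool()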

def test_self_attention():
    # Set up the hyperparameters
    batch_size = 2
    seq_length = 3
    d_k = 4

    # Create random inputs
    q = torch.randn(batch_size, seq_length, d_k)
    k = torch.randn(batch_size, seq_length, d_k)
    v = torch.randn(batch_size, seq_length, d_k)
    
    # Create a mask (optional; all ones means nothing is masked)
    mask = torch.ones(batch_size, seq_length, seq_length).bool()

    # Create a dropout layer
    dropout = nn.Dropout(p=0.1)

    # Call the SelfAttention function
    output = SelfAttention(q, k, v, mask, dropout)

    # Check the output shape
    expected_shape = (batch_size, seq_length, d_k)
    assert output.shape == expected_shape, f"SelfAttention output shape is wrong: expected {expected_shape}, got {output.shape}"

    print("SelfAttention test passed! Output shape is correct.")
    print(f"SelfAttention output tensor:\n{output}")
    print(f"SelfAttention output shape: {output.shape}")

def test_multi_head_attention():
    # Set up the hyperparameters
    batch_size = 32
    seq_length = 10
    d_model = 512
    n_head = 8

    # Create random inputs
    q = torch.randn(batch_size, seq_length, d_model)
    k = torch.randn(batch_size, seq_length, d_model)
    v = torch.randn(batch_size, seq_length, d_model)
    
    # Create a mask (optional; all ones means nothing is masked)
    mask = torch.ones(batch_size, seq_length, seq_length).bool()

    # Instantiate the MultiHeadAttention module
    mha = MultiHeadAttention(n_head, d_model)

    # Forward pass
    output = mha(q, k, v, mask)

    # Check the output shape
    expected_shape = (batch_size, seq_length, d_model)
    assert output.shape == expected_shape, f"MultiHeadAttention output shape is wrong: expected {expected_shape}, got {output.shape}"

    print("MultiHeadAttention test passed! Output shape is correct.")
    print(f"MultiHeadAttention output tensor:\n{output}")
    print(f"MultiHeadAttention output shape: {output.shape}")

if __name__ == "__main__":
    test_self_attention()
    print("\n" + "="*50 + "\n")
    test_multi_head_attention()
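
As a sanity check, the hand-written SelfAttention can be compared against PyTorch's built-in F.scaled_dot_product_attention (available since PyTorch 2.0). The following is a minimal sketch under that version assumption; the test name and tensor sizes are arbitrary, and dropout is disabled so both paths are deterministic and should agree up to floating-point tolerance.

def test_against_pytorch_builtin():
    # Small random inputs; shapes are (batch, seq_len, d_k)
    q = torch.randn(2, 5, 16)
    k = torch.randn(2, 5, 16)
    v = torch.randn(2, 5, 16)
    # Boolean causal mask: True = attend, False = masked out,
    # matching the mask == 0 convention used in SelfAttention
    mask = torch.tril(torch.ones(2, 5, 5)).bool()

    ours = SelfAttention(q, k, v, mask, None)
    builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    assert torch.allclose(ours, builtin, atol=1e-5), "Output differs from the built-in implementation"
    print("Hand-written attention matches F.scaled_dot_product_attention.")

Running test_against_pytorch_builtin() after the two tests above should print the confirmation message on a recent PyTorch build.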