DiAD代码逐行理解之Attention

一、代码中的AttentionBlock类

这个 AttentionBlock 类实现了一个自注意力块,它允许空间位置之间相互注意。

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        #return pt_checkpoint(self._forward, x)  # pytorch
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
    def _forward(self, x):
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

构造函数 init
channels: 输入和输出张量的通道数。
num_heads: 自注意力的头数。默认为1,表示不使用多头注意力。
num_head_channels: 每个头的通道数。如果设置为-1(默认值),则使用num_heads参数来确定头的数量;否则,根据num_head_channels来分割通道,从而确定头的数量。
use_checkpoint: 是否在前向传播中使用梯度检查点(checkpointing)来节省内存。这对于训练深层模型特别有用,因为它可以减少内存消耗。
use_new_attention_order: 指定使用哪种顺序来拆分QKV(查询、键、值)和头。如果为True,则先拆分QKV,然后再拆分头;如果为False(默认值),则先拆分头,然后再拆分QKV。
构造函数还初始化了归一化层(self.norm)、QKV投影层(self.qkv)、自注意力层(self.attention,根据use_new_attention_order选择使用QKVAttention或QKVAttentionLegacy)以及输出投影层(self.proj_out)。
前向传播 forward
在这个版本的forward方法中,有一个注释掉的行,它调用了checkpoint函数来尝试使用梯度检查点来优化内存使用。然而,这一行被注释掉了,并且有一个TODO注释提到了检查checkpoint的用法。在实际应用中,你可能需要根据你的具体需求和PyTorch版本决定是否启用这一行。
注意,forward方法实际上并没有直接实现前向传播逻辑;相反,它准备了一个调用_forward方法的包装,以便(可选地)使用梯度检查点。
私有前向传播 _forward
_forward方法实现了注意力块的实际前向传播逻辑。
它首先将输入张量x重塑为[b, c, -1]形状,以便将空间维度展平。
然后,它应用归一化层、QKV投影层,并将结果传递给自注意力层。
自注意力层的输出通过输出投影层,并与原始输入x相加,实现残差连接。
最后,将结果重塑回原始的空间维度形状,并返回。
下面是对这个方法更详细的解释:

输入张量形状解析:
b, c, *spatial = x.shape:这行代码解析输入张量 x 的形状,其中 b 是批次大小(batch size),c 是通道数(channels),*spatial 是一个包含所有空间维度的元组(例如,对于二维图像,spatial 可能是 (height, width);对于三维数据,它可能是 (depth, height, width) 等)。
重塑输入张量:
x = x.reshape(b, c, -1):这行代码将输入张量 x 的空间维度展平。-1 表示自动计算该维度的大小,以便保持元素总数不变。这样做是为了将空间信息“折叠”到一个维度中,以便后续处理。
通过QKV投影和归一化:
qkv = self.qkv(self.norm(x)):首先,通过 self.norm 对重塑后的 x 进行归一化处理(可能是批归一化、层归一化等)。然后,将归一化后的结果传递给 self.qkv,这是一个线性层(或一组线性层),用于将输入投影到查询(Q)、键(K)和值(V)空间。注意,这里假设 self.qkv 的输出维度已经设计为 [b, 3
c’, T] 的形式,其中 c’ 是每个头的通道数,T 是重塑后的空间维度大小。
应用自注意力机制:
h = self.attention(qkv):将QKV投影的结果传递给自注意力模块 self.attention。这个模块将执行多头注意力机制,计算注意力权重,并使用这些权重对值向量进行加权求和,从而得到考虑了上下文信息的特征表示 h。
输出投影:
h = self.proj_out(h):将自注意力模块的输出 h 传递给 self.proj_out,这通常是一个线性层,用于将注意力模块的输出投影回原始通道数 c(或可能是一个不同的通道数,具体取决于设计)。
残差连接和形状恢复:
(x + h).reshape(b, c, *spatial):将原始输入 x 与注意力模块的输出 h 相加,实现残差连接。这有助于缓解深层网络中的梯度消失问题。然后,将相加的结果重新塑形为原始输入张量 x 的形状 [b, c, *spatial],以便与网络的其余部分兼容。
总之,这个方法通过结合自注意力机制和残差连接,有效地提取了输入数据的上下文信息,并保持了网络对深层特征的表示能力。
注意事项
在使用梯度检查点时,需要确保你的PyTorch版本支持这一功能,并且你理解它对内存使用和计算性能的影响。
在实际应用中,你可能需要根据你的具体任务和数据集调整注意力块的参数(如头数、每个头的通道数等)。
注释中提到的.half()调用可能是指将张量转换为半精度浮点数(fp16),这在某些情况下可以提高计算速度和减少内存使用,但也可能导致数值稳定性问题。如果你的模型或任务对数值精度要求较高,你可能需要谨慎使用。然而,在这个注意力块的实现中,并没有直接调用.half()方法。

二、代码中的QKVAttention 类和QKVAttentionLegacy类

class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)

这两个类 QKVAttentionLegacy 和 QKVAttention 都实现了自注意力机制中的 QKV(查询、键、值)注意力模式,但它们在处理输入张量的方式上有所不同。下面是它们之间差异的详细解释:
QKVAttentionLegacy 类
输入处理:该类接受一个形状为 [N x (H * 3 * C) x T] 的 qkv 张量,其中 N 是批次大小,H 是头数(n_heads),C 是每个头的通道数,T 是序列长度。这个张量包含了查询(Q)、键(K)和值(V)的信息,它们被连续地排列在一起。
拆分处理:使用 .reshape() 方法将 qkv 张量重塑为 [bs * self.n_heads, ch * 3, length],然后通过 .split() 方法沿着第二个维度(通道维度)将其拆分为查询、键和值三部分。
注意力计算:使用 torch.einsum 进行矩阵乘法来计算注意力权重,并通过 softmax 进行归一化。最后,使用 torch.einsum 将归一化的注意力权重与值进行加权求和。
QKVAttention 类
输入处理:与 QKVAttentionLegacy 类似,该类也接受一个形状为 [N x (3 * H * C) x T] 的 qkv 张量,但这里的关键区别在于如何处理这个张量。
拆分处理:使用 .chunk() 方法沿着第一个维度(通常是通道维度)将 qkv 张量拆分为查询、键和值三部分。这与 QKVAttentionLegacy 使用 .reshape() 和 .split() 的方式不同。
注意力计算:注意力权重的计算过程与 QKVAttentionLegacy 相同,都是先通过 torch.einsum 进行矩阵乘法,然后通过 softmax 归一化,最后使用归一化的权重与值进行加权求和。
主要差异

输入张量拆分方式:QKVAttentionLegacy 使用 .reshape() 和 .split() 来拆分 qkv 张量,而 QKVAttention 使用 .chunk() 直接拆分。
拆分维度:QKVAttentionLegacy 沿着通道维度拆分(但先重塑为 [bs * self.n_heads, ch * 3, length]),而 QKVAttention 直接沿着原始 qkv 张量的第一个维度拆分。
注意事项
这两个类在注意力权重的计算和最终的输出形状上是相同的,主要区别在于处理输入张量的方式。在实际应用中,选择哪种方式取决于具体的场景和性能需求。
count_flops 静态方法在两个类中都是用来计算模型浮点运算次数的。

  • 8
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值