Detailed Introduction to the Transformer Architecture

Key Idea

A standard Transformer module expects a sequence of tokens (vectors) as input, i.e. a 2D matrix [num_token, token_dim], not a feature map.

A token here means a vector with representational capacity.
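
As a concrete illustration (a minimal sketch, separate from the module below), a CNN feature map of shape [B, C, H, W] can be turned into such a token sequence by flattening its spatial dimensions:

import torch

feature_map = torch.randn(2, 256, 8, 8)           # [B, C, H, W], hypothetical sizes
tokens = feature_map.flatten(2).permute(0, 2, 1)  # [B, H*W, C]: each sample is a [num_token, token_dim] matrix
print(tokens.shape)                               # torch.Size([2, 64, 256])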

Code

(The code below may contain errors; focus on the idea rather than the exact implementation.)

"""A transformer for computer vision models.
"""

import functools
import math
from typing import Optional

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F


class MHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads=1, dropout=0.0, bias=True):  # e.g. query_dim=100, hidden_dim=256, num_heads=8
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        weights = self.dropout(weights)
        return weights


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # all-False padding mask; use a bool tensor so that ~mask gives True/False (on a uint8 tensor, ~0 would give 255)
        mask = torch.zeros((x.shape[0],) + x.shape[2:], dtype=torch.bool, device=x.device)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None].to(device=x.device) / dim_t
        pos_y = y_embed[:, :, :, None].to(device=x.device) / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class Transformer(nn.Module):
    """ A transformer module used for computer vision tasks.
    First project the feature map into tokens as the input of the transformer.

    Parameters:
        num_tokens: number of tokenizer (l)
        dim_tokens: dimension of tokenizer (ct)
        num_channels: channel number of the feature map (c)
        num_heads: for multi-head self-attention. embed_dim must be divisible by num_heads
        num_groups: group conv as linear transformation for q, k, v

        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
    """

    def __init__(self, num_tokens, num_channels, output_channels=None, num_queries=None, num_heads=1, num_groups=1,
                 down_sample=(8, 8), position_encoding='points', use_decoder=True,
                 positional_decoder=False, attention_for_seg=False, downsampling=False):
        super().__init__()

        self.num_tokens = num_tokens
        self.dim_tokens = down_sample[0] * down_sample[1]
        self.num_channels = num_channels
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.down_sample = down_sample
        self.position_encoding = position_encoding
        self.use_decoder = use_decoder
        self.positional_decoder = positional_decoder
        self.attention_for_seg = attention_for_seg
        self.downsampling = downsampling

        if use_decoder and num_queries is None:
            num_queries = self.num_tokens

        if output_channels is None:
            self.output_channels = num_channels
        else:
            self.output_channels = output_channels

        # c -> l: a 1x1 conv that produces a 2D attention score per token. It can be seen as a
        # convolutional filter that divides the feature map into regions corresponding to
        # different semantic concepts.
        if position_encoding == 'points':
            self.input_proj = nn.Conv2d(self.num_channels + 2, self.num_tokens, kernel_size=1, bias=False)
        else:
            self.input_proj = nn.Conv2d(self.num_channels, self.num_tokens, kernel_size=1, bias=False)
        self.input_norm = nn.BatchNorm2d(self.num_tokens)

        # Transformer Encoder
        self.encoder = TransformerEncoderLayer(self.dim_tokens, nhead=self.num_heads)
        if use_decoder:
            self.query_embed = nn.Embedding(num_queries, self.dim_tokens)
            self.decoder = TransformerDecoderLayer(self.dim_tokens, nhead=self.num_heads)
            if positional_decoder and position_encoding == 'points':
                self.num_tokens += 2

        self.reverse_proj = nn.Conv2d(self.num_tokens, self.output_channels, kernel_size=1, bias=False)
        self.reverse_norm = nn.BatchNorm2d(self.output_channels)

        if attention_for_seg:
            self.attention = MHAttentionMap(num_queries, self.dim_tokens)
            self.att_proj = nn.Conv2d(self.dim_tokens, num_queries, kernel_size=1, bias=False)


        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input: Tensor):
        batch_size, _, height, width = input.shape
        feature = input
        if self.down_sample is not None:
            feature = resize(feature,
                             self.down_sample,
                             mode='bilinear',
                             align_corners=False)

        if self.position_encoding == 'points':
            position_embedding = torch.from_numpy(get_points_single(feature.shape[-2:])) \
                .unsqueeze(0).repeat(batch_size, 1, 1, 1).to(feature.device)
            feature = torch.cat([feature, position_embedding], dim=1)

        feature = self.input_proj(feature)
        feature = self.input_norm(feature)

        feature = feature.flatten(2).permute(1, 0, 2)  # l, b, hw
        memory = self.encoder(feature)
        if self.use_decoder:
            query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
            if self.positional_decoder and self.position_encoding == 'points':
                feature = self.decoder(query_embed, memory, position_embedding)
            else:
                feature = self.decoder(query_embed, memory)
        else:
            feature = memory

        feature = feature.view((self.num_tokens, batch_size) + self.down_sample).permute(1, 0, 2, 3)

        if self.attention_for_seg:
            hs = feature.flatten(2).permute(0, 2, 1)
            memory = memory.view((self.num_tokens, batch_size) + self.down_sample).permute(1, 0, 2, 3)
            att = self.attention(hs, memory).squeeze(2)
            att = self.att_proj(att)
            feature = feature * att

        feature = self.reverse_proj(feature)
        feature = self.reverse_norm(feature)

        if self.downsampling:
            height = height // 2
            width = width // 2

        feature = resize(feature,
                         (height, width),
                         mode='bilinear',
                         align_corners=False)
        if self.downsampling or self.num_channels != self.output_channels:
            return feature
        else:
            return feature + input



class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=None, nhead=1, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()

        if dim_feedforward is None:
            dim_feedforward = d_model

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def forward_post(self, src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None):
        q = k = src
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)  # add self-att
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))  # matrix multiply
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = src2
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask)
        return self.forward_post(src, src_mask, src_key_padding_mask)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=None, nhead=1, dropout=0.1,
                 activation="relu"):
        super().__init__()

        if dim_feedforward is None:
            dim_feedforward = d_model

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, query, memory, position_embedding=None):
        if position_embedding is not None:
            query = torch.cat([query, position_embedding.flatten(2).permute(1, 0, 2)], dim=0)
            memory = torch.cat([memory, position_embedding.flatten(2).permute(1, 0, 2)], dim=0)

        tgt = self.multihead_attn(query=query,
                                  key=memory,
                                  value=memory)[0]
        tgt = memory + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt


def get_points_single(size, stride=1, dtype=np.float32):
    """The vanilla version of positional encoding (2 channels)."""

    height, width = size
    x_range = np.arange(0, width * stride, stride, dtype=dtype)
    y_range = np.arange(0, height * stride, stride, dtype=dtype)
    y_channel, x_channel = np.meshgrid(y_range, x_range)
    points = (np.stack((x_channel, y_channel)) + stride // 2).transpose((1, 2, 0))
    points = (points - points.mean(axis=(0, 1))) / points.std(axis=(0, 1))
    return points.transpose((2, 1, 0))


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None):
    return F.interpolate(input, size, scale_factor, mode, align_corners)
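
A minimal smoke test for the Transformer module above (appended to the same file; the hyper-parameters are arbitrary values chosen only to check shapes, not settings from any paper):

if __name__ == "__main__":
    model = Transformer(num_tokens=16, num_channels=64, num_heads=1,
                        down_sample=(8, 8), position_encoding='points', use_decoder=True)
    x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
    out = model(x)
    print(out.shape)                 # expected: torch.Size([2, 64, 32, 32]) (residual branch added back)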

Understanding the positional encoding: PositionEmbeddingSine

Point-based positional encoding

Input: b x c x h x w.

Only the spatial position of each location needs to be recorded, as a coordinate pair (x, y).

So after adding the positional encoding, the shape becomes:

b x (c + 2) x h x w

where the two extra channels hold the normalized x and y coordinates produced by get_points_single, as illustrated below.
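
For intuition, a quick check of the two coordinate channels produced by get_points_single (using the helper from the listing above; the grid size is arbitrary):

pts = get_points_single((4, 4))   # (height, width)
print(pts.shape)                  # (2, 4, 4): channel 0 = normalized x coordinates, channel 1 = normalized y coordinates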

Sine/cosine positional encoding

Here the raw indices 0, 1, 2, 3, ... are replaced by sine and cosine values at multiple frequencies, so the encoding can represent positions more richly. The shape becomes:

b x (c + dim) x h x w

where the dim extra channels are computed by PositionEmbeddingSine, shown below.
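
Concretely, the standard sinusoidal formulation from "Attention Is All You Need", which PositionEmbeddingSine below applies separately to the x and y coordinates (with d = num_pos_feats per axis), is:

PE(pos, 2i)   = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))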

class PositionEmbeddingSine(nn.Module):
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # all-False padding mask; use a bool tensor so that ~mask behaves as logical NOT
        mask = torch.zeros((x.shape[0],) + x.shape[2:], dtype=torch.bool, device=x.device)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None].to(device=x.device) / dim_t
        pos_y = y_embed[:, :, :, None].to(device=x.device) / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
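
A quick shape check (assuming the class as defined above; the output has 2 * num_pos_feats channels, with the y encoding first and the x encoding second):

pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
x = torch.randn(2, 256, 8, 8)   # [B, C, H, W]
pos = pe(x)
print(pos.shape)                # torch.Size([2, 128, 8, 8])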

Understanding multi-head attention: MHAttentionMap

Code

class MHAttentionMap(nn.Module):
    """2D attention module, which only returns the attention softmax (no multiplication by value)."""
    def __init__(self, query_dim, hidden_dim, num_heads=1, dropout=0.0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k):
        q = self.q_linear(q)
        # k is a feature map [B, C, H, W]; apply k_linear as a 1x1 convolution (as in the full listing above)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        weights = self.dropout(weights)
        return weights

Q and K are split along the channel dimension into num_heads groups, one per attention head.

Scaled dot-product attention scores are then computed between them and the softmax-normalized weights are returned; unlike full QKV attention, the values V are not multiplied in here.
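
A shape sketch for MHAttentionMap (the sizes are illustrative assumptions): q is a batch of query vectors and k is a feature map, and the returned weights contain one H x W attention map per query and per head:

att_map = MHAttentionMap(query_dim=256, hidden_dim=256, num_heads=8)
q = torch.randn(2, 100, 256)    # [B, num_queries, query_dim]
k = torch.randn(2, 256, 8, 8)   # [B, query_dim, H, W] feature map
w = att_map(q, k)
print(w.shape)                  # torch.Size([2, 100, 8, 8, 8]) = [B, num_queries, num_heads, H, W]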

Encoder

First check whether positional encoding is used; if so, concatenate the positional information onto the feature map:

if self.position_encoding == 'points':
    position_embedding = torch.from_numpy(get_points_single(feature.shape[-2:])) \
        .unsqueeze(0).repeat(batch_size, 1, 1, 1).to(feature.device)
    feature = torch.cat([feature, position_embedding], dim=1)  # channel count becomes self.num_channels + 2

Before the encoder, a projection and/or downsampling step is usually applied to reduce the amount of computation:

if self.down_sample is not None:
    feature = resize(feature,
                     self.down_sample,
                     mode='bilinear',
                     align_corners=False)
feature = self.input_proj(feature)
feature = self.input_norm(feature)

Flatten the spatial dimensions: the input b x l x h x w becomes b x l x (h*w), then permute to (l, b, h*w) and feed it into the encoder (see the shape trace after the code):

feature = feature.flatten(2).permute(1, 0, 2)  # l, b, hw
memory = self.encoder(feature)
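
A shape trace for this step (hypothetical batch size 2 and num_tokens l = 16, with down_sample = (8, 8), so each token has dimension 64):

f = torch.randn(2, 16, 8, 8)            # [b, l, h, w] after input_proj / input_norm
seq = f.flatten(2).permute(1, 0, 2)     # [l, b, h*w]
print(seq.shape)                        # torch.Size([16, 2, 64]); the encoder keeps this shape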

Key components of the encoder

1. Multi-head self-attention

self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)

2. Feed-forward network (two linear layers)

self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

3. Layer normalization

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)

4. Dropout layers

self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

The order in which these components are combined:

if self.normalize_before:
    return self.forward_pre(src, src_mask, src_key_padding_mask)
return self.forward_post(src, src_mask, src_key_padding_mask)
def forward_post(self, src,
                 src_mask: torch.Tensor = None,
                 src_key_padding_mask: torch.Tensor = None):
  • If normalize_before is False (post-norm), self-attention comes first, followed by normalization and the feed-forward network:

q = k = src
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                      key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)  # add self-att
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))  # matrix multiply
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
                src_mask: torch.Tensor = None,
                src_key_padding_mask: torch.Tensor = None):
  • If normalize_before is True (pre-norm), normalization comes first, then self-attention and the feed-forward network:

src2 = self.norm1(src)
q = k = src2
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                      key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
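
A quick usage check of the encoder layer on a dummy token sequence (the shapes follow nn.MultiheadAttention's (L, N, E) convention; the numbers are illustrative):

enc = TransformerEncoderLayer(d_model=64, nhead=1)
src = torch.randn(16, 2, 64)   # [num_tokens, batch, dim_tokens]
out = enc(src)
print(out.shape)               # torch.Size([16, 2, 64])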

Decoder

Key components of the decoder

  • The multi-head (cross-)attention layer.

self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
  • The two linear layers of the feed-forward network.

self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
  • Two layer-normalization layers and two dropout layers.

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
  • The activation function.

self.activation = _get_activation_fn(activation)
  • The inputs: query is the query tensor, memory is the encoder output, and position_embedding is an optional positional encoding.

def forward(self, query, memory, position_embedding=None):
    if position_embedding is not None:
        query = torch.cat([query, position_embedding.flatten(2).permute(1, 0, 2)], dim=0)
        memory = torch.cat([memory, position_embedding.flatten(2).permute(1, 0, 2)], dim=0)
  • If a positional encoding is provided, it is concatenated to both the query and the memory tensors along the sequence dimension.

  • Compute multi-head cross-attention, add the result to the memory tensor (the residual used here), then apply normalization and the feed-forward network:

tgt = self.multihead_attn(query=query,
                          key=memory,
                          value=memory)[0]
tgt = memory + self.dropout1(tgt)
tgt = self.norm1(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
return tgt
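
A corresponding check for the decoder layer; note that the residual tgt = memory + ... above requires the query and memory sequence lengths to match (which holds because num_queries defaults to num_tokens):

dec = TransformerDecoderLayer(d_model=64, nhead=1)
query = torch.randn(16, 2, 64)    # [num_queries, batch, dim_tokens], e.g. query_embed.weight repeated over the batch
memory = torch.randn(16, 2, 64)   # encoder output of the same length
out = dec(query, memory)
print(out.shape)                  # torch.Size([16, 2, 64])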
