Convolution Transformer
It is an improvement on the Vision Transformer. The authors modify the Transformer in two main ways: the position embedding and the linear projection are replaced with convolution operations.
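As a rough illustration of the second change (the shapes and layer choices here are my own, for illustration only, not taken from the paper): a ViT block projects queries with a plain linear layer on the token sequence, while the convolutional version reshapes the tokens back into a 2-D map and projects them with a depthwise-separable convolution, so each query also mixes in its spatial neighbours.

import torch
import torch.nn as nn
from einops import rearrange

tokens = torch.randn(1, 56 * 56, 64)  # (batch, num_tokens, dim) from a 56x56 feature map

# ViT-style query projection: a plain linear layer on the token sequence
q_vit = nn.Linear(64, 64)(tokens)

# Convolutional query projection: reshape tokens to a 2-D map and apply a
# depthwise-separable convolution before flattening back to a sequence
x = rearrange(tokens, 'b (h w) c -> b c h w', h=56)
conv_q = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64),  # depthwise 3x3
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 64, kernel_size=1),                         # pointwise 1x1
)
q_conv = rearrange(conv_q(x), 'b c h w -> b (h w) c')
print(q_vit.shape, q_conv.shape)  # both: torch.Size([1, 3136, 64])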
The official code has not been open-sourced yet; my own full implementation follows.
# Personal implementation of the ConvTransformer
"""
author:brotherHappy
date:2021.4.23
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout (LayerNorm is applied by PreNorm)."""

    def __init__(self, dim, drop_p=0.1, expand=3):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, dim * expand)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(dim * expand, dim)
        self.drop = nn.Dropout(p=drop_p)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act1(x)
        x = self.drop(x)
        x = self.fc2(x)
        return self.drop(x)

class SepConv2d(nn.Module):
    """Depthwise-separable convolution: depthwise conv -> BatchNorm -> pointwise 1x1 conv."""

    def __init__(self, dim, out_dim, kernel_size, stride=1, padding=0, dilation=1):
        super(SepConv2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size, stride=stride,
                               padding=padding, dilation=dilation, groups=dim)
        self.norm = nn.BatchNorm2d(num_features=dim)
        self.conv2 = nn.Conv2d(in_channels=dim, out_channels=out_dim, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.norm(self.conv1(x)))

class ConvTransEmbedding(nn.Module):
    """
    Convolutional token embedding.
    in:  (B, H*W, dim)
    out: (B, H'*W', out_dim)
    """
    # Tracks the input resolution of the next stage (starts at 224),
    # so the stages must be constructed in order.
    resolution = 224

    def __init__(self, dim, out_dim, kernel_size, resolution, stride=4):
        super(ConvTransEmbedding, self).__init__()
        # Padding chosen so the strided conv maps the previous resolution down to `resolution`.
        pad = ((resolution - 1) * stride + kernel_size - ConvTransEmbedding.resolution + 1) // 2
        self.embedding = SepConv2d(dim=dim, out_dim=out_dim, kernel_size=kernel_size, stride=stride, padding=pad)
        ConvTransEmbedding.resolution = resolution

    def forward(self, x):  # (B, H*W, C)
        b, hw, c = x.shape
        x = rearrange(x, 'b (h w) c -> b c h w', h=int(hw ** 0.5))
        x = self.embedding(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x

class Residual(nn.Module):
    def __init__(self, func):
        super(Residual, self).__init__()
        self.func = func

    def forward(self, x):
        return x + self.func(x)


class PreNorm(nn.Module):
    def __init__(self, dim, func):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(normalized_shape=dim)
        self.func = func

    def forward(self, x):
        return self.func(self.norm(x))

class MHA(nn.Module):
    """Multi-head self-attention with convolutional (depthwise-separable) Q/K/V projections."""

    def __init__(self, dim, q_s, v_s, k_s, heads, resolution, kernel_size=3):
        super(MHA, self).__init__()
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.resolution = resolution
        # Padding keeps the spatial resolution through the (possibly strided) projections.
        pad_v = ((resolution - 1) * v_s + kernel_size - resolution + 1) // 2
        pad_qk = ((resolution - 1) * q_s + kernel_size - resolution + 1) // 2
        self.conv_q = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=q_s,
                                padding=pad_qk)
        self.conv_k = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=k_s,
                                padding=pad_qk)
        self.conv_v = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=v_s,
                                padding=pad_v)

    def forward(self, x):
        b, l, c = x.shape
        x = rearrange(x, 'b (h w) c -> b c h w', h=self.resolution)
        # Convolutional projections, then split channels into heads: (B, C, H, W) -> (B*heads, H*W, C/heads)
        q = rearrange(self.conv_q(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)
        k = rearrange(self.conv_k(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)
        v = rearrange(self.conv_v(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)
        dots = einsum('b l c,b w c -> b l w', q, k) * self.scale  # attention logits: (B*heads, L, L)
        attn = dots.softmax(-1)
        out = einsum('b m n,b n c -> b m c', attn, v)
        # Merge the heads back into the channel dimension
        out = rearrange(out, '(b h) l n -> b l (h n)', h=self.heads)
        return out

class Stage(nn.Module):
    def __init__(self, heads, expand, depth, dim, out_dim, resolution, stride, q_s, k_s, v_s, kernel_size,
                 drop_p=0.1):
        super(Stage, self).__init__()
        self.conv_token_embedding = ConvTransEmbedding(dim=dim, out_dim=out_dim, kernel_size=kernel_size,
                                                       resolution=resolution, stride=stride)
        blocks = nn.ModuleList()
        for _ in range(depth):
            # Pre-norm attention block (the convolutional projection keeps its default 3x3 kernel)
            blocks.append(Residual(PreNorm(dim=out_dim,
                                           func=MHA(dim=out_dim, q_s=q_s, v_s=v_s, k_s=k_s, heads=heads,
                                                    resolution=resolution))))
            # Pre-norm MLP block
            blocks.append(Residual(PreNorm(dim=out_dim,
                                           func=MLP(dim=out_dim, drop_p=drop_p, expand=expand))))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv_token_embedding(x)
        return self.blocks(x)

class ConvTransformer(nn.Module):
    def __init__(self, stride=(4, 2, 2), kernel_size=(7, 3, 3), dim=(3, 64, 192, 384), expand=(4, 4, 4),
                 num_classes=1000, depth=(1, 2, 10), heads=(1, 3, 6), resolution=(56, 28, 14), q_s=(1, 1, 1),
                 v_s=(1, 1, 1), k_s=(1, 1, 1), drop_p=0.1):
        super(ConvTransformer, self).__init__()
        layers = nn.ModuleList()
        for i in range(3):
            layers.append(Stage(heads=heads[i], expand=expand[i], depth=depth[i], dim=dim[i], out_dim=dim[i + 1],
                                resolution=resolution[i], stride=stride[i], q_s=q_s[i], v_s=v_s[i], k_s=k_s[i],
                                kernel_size=kernel_size[i], drop_p=drop_p))
        self.layers = nn.Sequential(*layers)
        # Classification head on the flattened final token map (resolution[2]**2 tokens of width dim[3])
        self.out_class = nn.Sequential(
            nn.Dropout(drop_p),
            nn.Linear(in_features=resolution[2] ** 2 * dim[3], out_features=num_classes)
        )

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.layers(x)
        x = rearrange(x, 'b l c -> b (l c)')
        x = self.out_class(x)
        return x
# Just run the line below to build the model
model = ConvTransformer()
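A quick sanity check of the forward pass (assuming the default 224x224 input that the resolution schedule 224 -> 56 -> 28 -> 14 expects):

imgs = torch.randn(2, 3, 224, 224)   # dummy ImageNet-sized batch
logits = model(imgs)
print(logits.shape)                  # torch.Size([2, 1000])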