CNN in Transformer

The Convolution Transformer is an improvement on the Vision Transformer. The authors mainly improve the Transformer in two respects: the position embedding is replaced, and the linear projections are replaced with convolution operations.

The official code has not been open-sourced yet, so I implemented it myself. A short illustration of the idea comes first, followed by the full implementation.
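As a rough sketch of the second point (an illustration only, not part of the implementation below; the 16x16 patch size and embedding width 64 are chosen arbitrarily): ViT turns an image into tokens with a linear projection of flattened patches, whereas a strided convolution can produce the same token grid directly.

import torch
import torch.nn as nn
from einops import rearrange

img = torch.randn(1, 3, 224, 224)

# ViT-style tokenization: flatten 16x16 patches, then apply a linear projection
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
vit_tokens = nn.Linear(16 * 16 * 3, 64)(patches)  # -> (1, 196, 64)

# Convolutional tokenization: a strided convolution yields the same token grid
conv_tokens = rearrange(nn.Conv2d(3, 64, kernel_size=16, stride=16)(img), 'b c h w -> b (h w) c')  # -> (1, 196, 64)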
# conv_transformer, my own implementation
"""
author:brotherHappy
date:2021.4.23
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum


class MLP(nn.Module):
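	"""Feed-forward block of a Transformer layer: Linear -> GELU -> Dropout -> Linear -> Dropout."""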
	def __init__(self, dim, drop_p=0.1, expand=3):
		super(MLP, self).__init__()
		self.fc1 = nn.Linear(dim, dim * expand)
		self.act1 = nn.GELU()
		self.fc2 = nn.Linear(dim * expand, dim)
		self.norm1 = nn.LayerNorm(dim)  # shape to be determined; currently unused in forward
		self.drop = nn.Dropout(p=drop_p)

	def forward(self, x):
		x = self.fc1(x)
		x = self.act1(x)
		x = self.drop(x)
		x = self.fc2(x)
		return self.drop(x)


class SepConv2d(nn.Module):
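	"""Depthwise separable convolution: depthwise conv (groups=dim) -> BatchNorm2d -> 1x1 pointwise conv to out_dim channels."""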
	def __init__(self, dim, out_dim, kernel_size, stride=1, padding=0, dilation=1):
		super(SepConv2d, self).__init__()
		self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size, stride=stride,
							   padding=padding, dilation=dilation, groups=dim)
		self.norm = nn.BatchNorm2d(num_features=dim)
		self.conv2 = nn.Conv2d(in_channels=dim, out_channels=out_dim, kernel_size=1)

	def forward(self, x):
		return self.conv2(self.norm(self.conv1(x)))


class ConvTransEmbedding(nn.Module):
	"""
	in:(B,dim,H,W)
	out:(B,out_dim,H',W')
	"""
	resolution = 224

	def __init__(self, dim, out_dim, kernel_size, resolution, stride=4):
		super(ConvTransEmbedding, self).__init__()
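		# padding chosen so an input of size ConvTransEmbedding.resolution comes out
		# at exactly `resolution` with this kernel size and stride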
		pad = ((resolution - 1) * stride + kernel_size - ConvTransEmbedding.resolution + 1) // 2
		self.embedding = SepConv2d(dim=dim, out_dim=out_dim, kernel_size=kernel_size, stride=stride, padding=pad)
		ConvTransEmbedding.resolution = resolution

	def forward(self, x):  # B, hw c
		b, hw, c = x.shape
		x = rearrange(x, 'b (h w) c -> b c h w', h=int(hw ** 0.5))
		x = self.embedding(x)
		x = rearrange(x, 'b c h w -> b (h w) c')
		return x


class Residual(nn.Module):
	def __init__(self, func):
		super(Residual, self).__init__()
		self.func = func

	def forward(self, x):
		return x + self.func(x)


class PreNorm(nn.Module):
	def __init__(self, dim, func):
		super(PreNorm, self).__init__()
		self.norm = nn.LayerNorm(normalized_shape=dim)
		self.func = func

	def forward(self, x):
		return self.func(self.norm(x))


class MHA(nn.Module):
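	"""Multi-head self-attention whose Q/K/V projections are separable convolutions on the 2-D feature map instead of linear layers."""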
	def __init__(self, dim, q_s, v_s, k_s, heads, resolution, kernel_size=3):
		super(MHA, self).__init__()
		self.heads = heads
		self.scale = (dim / heads) ** -0.5
		self.resolution = resolution
		# padding chosen so each projected feature map keeps its spatial size (resolution x resolution) for its stride
		pad_q = ((resolution - 1) * q_s + kernel_size - resolution + 1) // 2
		pad_k = ((resolution - 1) * k_s + kernel_size - resolution + 1) // 2
		pad_v = ((resolution - 1) * v_s + kernel_size - resolution + 1) // 2
		self.conv_q = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=q_s, padding=pad_q)
		self.conv_k = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=k_s, padding=pad_k)
		self.conv_v = SepConv2d(dim=dim, out_dim=dim, kernel_size=kernel_size, stride=v_s, padding=pad_v)

	def forward(self, x):
		b, l, c = x.shape
		x = rearrange(x, 'b (h w) c -> b c h w', h=self.resolution)
		q = rearrange(self.conv_q(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)  # b d h w -> B L C
		k = rearrange(self.conv_k(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)  # b d h w
		v = rearrange(self.conv_v(x), 'b (h n) l w -> (b h) (l w) n', h=self.heads)  # b d h w
		dots = einsum('b l c,b w c -> b l w', q, k) * self.scale  # b l l
		attn = dots.softmax(-1)

		out = einsum('b m n,b n c -> b m c', attn, v)  # B l n
		out = rearrange(out, '(b h) l n -> b l (h n)', h=self.heads)  # merge the heads back: (B, L, dim)
		return out


class Stage(nn.Module):
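	"""One stage: a convolutional token embedding followed by `depth` Transformer blocks (pre-norm attention + pre-norm MLP, each with a residual connection)."""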
	def __init__(self, heads, expand, depth, dim, out_dim, resolution, stride, q_s, k_s, v_s, kernel_size,
				 drop_p=0.1):
		super(Stage, self).__init__()
		self.conv_token_embedding = ConvTransEmbedding(dim=dim, out_dim=out_dim, kernel_size=kernel_size,
													   resolution=resolution, stride=stride)
		blocks = nn.ModuleList()
		for i in range(depth):
			# pre-norm attention block with a residual connection
			blocks.append(Residual(PreNorm(
				dim=out_dim,
				func=MHA(dim=out_dim, q_s=q_s, v_s=v_s, k_s=k_s, heads=heads, resolution=resolution))))
			# pre-norm feed-forward block with a residual connection
			blocks.append(Residual(PreNorm(
				dim=out_dim,
				func=MLP(dim=out_dim, drop_p=drop_p, expand=expand))))
		self.blocks = nn.Sequential(*blocks)

	def forward(self, x):
		x = self.conv_token_embedding(x)
		return self.blocks(x)


class ConvTransformer(nn.Module):
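	"""Three-stage model: token maps of 56x56 -> 28x28 -> 14x14 with 64/192/384 channels by default, then dropout and a linear classifier over the flattened output tokens."""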
	def __init__(self, stride=(4, 2, 2), kernel_size=(7, 3, 3), dim=(3, 64, 192, 384), expand=(4, 4, 4),
				 num_classes=1000, depth=(1, 2, 10), heads=(1, 3, 6), resolution=(56, 28, 14), q_s=(1, 1, 1),
				 v_s=(1, 1, 1), k_s=(1, 1, 1), drop_p=0.1):
		super(ConvTransformer, self).__init__()
		ConvTransEmbedding.resolution = 224  # reset the resolution tracker so a second model also starts from 224x224
		layers = nn.ModuleList()
		for i in range(3):
			layers.append(Stage(heads=heads[i], expand=expand[i], depth=depth[i], dim=dim[i], out_dim=dim[i + 1],
								resolution=resolution[i], stride=stride[i], q_s=q_s[i], v_s=v_s[i], k_s=k_s[i],
								kernel_size=kernel_size[i], drop_p=drop_p))
		self.layers = nn.Sequential(*layers)
		self.out_class = nn.Sequential(
			nn.Dropout(drop_p),
			nn.Linear(in_features=resolution[2] ** 2 * dim[3], out_features=num_classes)
		)

	def forward(self, x):
		x = rearrange(x, 'b c h w -> b (h w) c')
		x = self.layers(x)
		x = rearrange(x, 'b l c -> b (l c)')
		x = self.out_class(x)
		return x

# to use it, simply instantiate the model:
model = ConvTransformer()
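For a quick sanity check, run the model on a random batch (the batch size here is arbitrary; the default configuration expects 224x224 inputs):

x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([2, 1000])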