EdgeViT
原文:EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers
代码
CNN用了PVT的典型架构
代码:
参考博客
import torch
import torch.nn as nn
# Configuration table for the EdgeViT variants, keyed by variant name.
# Each entry holds three 4-tuples: 'channels', 'blocks' and 'heads'
# (presumably one value per network stage -- TODO confirm against the model
# builder that consumes this table).
edgevit_configs = dict(
    XXS=dict(
        channels=(36, 72, 144, 288),
        blocks=(1, 1, 3, 2),
        heads=(1, 2, 4, 8),
    ),
    XS=dict(
        channels=(48, 96, 240, 384),
        blocks=(1, 1, 2, 2),
        heads=(1, 2, 4, 8),
    ),
    S=dict(
        channels=(48, 96, 240, 384),
        blocks=(1, 2, 3, 2),
        heads=(1, 2, 4, 8),
    ),
)
# Shared hyperparameters that do not vary between variants.
# NOTE(review): 'r' is presumably a per-stage ratio (e.g. a sampling rate)
# consumed by code outside this chunk -- confirm its meaning at the use site.
HYPERPARAMETERS = dict(r=(4, 2, 2, 1))
class Residual(nn.Module):
    """Wrap a module with a skip (residual) connection.

    The forward pass returns ``x + module(x)``, so the wrapped module must
    produce an output that can be added to its input.
    """

    def __init__(self, module: nn.Module) -> None:
        """Store the module whose output is added back onto its input.

        Args:
            module: The sub-module to wrap.
        """
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input plus the wrapped module's output for that input."""
        return x + self.module(x)
class ConditionalPositionalEncoding(nn.Module):
"""
"""
def __init__(self, channels):
super(ConditionalPositionalEncoding, self).__init__()
self.conditional_positional_encoding = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels,
bias=False)
def forward