import torch
from model.backbone import resnet
import numpy as np
from model.FPN_neck import *
class conv_bn_relu(torch.nn.Module):
def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=False):
super(conv_bn_relu,self).__init__()
self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size,
stride = stride, padding = padding, dilation = dilation,bias = bias)
self.bn = torch.nn.BatchNorm2d(out_channels)
self.relu = torch.nn.ReLU()
def forward(self,x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class parsingNet(torch.nn.Module):
def __init__(self, size=(288, 800), pretrained=True, backbone='50', cls_dim=(37, 10, 4), use_aux=False):
super(parsingNet, self).__init__()
        self.size = size
        self.h = size[0]  # size is (height, width), e.g. (288, 800)
        self.w = size[1]
self.cls_dim = cls_dim # (num_gridding, num_cls_per_lane, num_of_lanes)
# num_cls_per_lane is the number of row anchors
self.use_aux = use_aux
self.total_dim = np.prod(cls_dim)
# input : nchw,
# output: (w+1) * sample_rows * 4
self.model = resnet(backbone, pretrained=pretrained)
if self.use_aux:
self.aux_header2 = torch.nn.Sequential(
conv_bn_relu(128, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1),
conv_bn_relu(128,128,3,padding=1),
conv_bn_relu(128,128,3,padding=1),
conv_bn_relu(128,128,3,padding=1),
)
self.aux_header3 = torch.nn.Sequential(
conv_bn_relu(256, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(1024, 128, kernel_size=3, stride=1, padding=1),
conv_bn_relu(128,128,3,padding=1),
conv_bn_relu(128,128,3,padding=1),
)
self.aux_header4 = torch.nn.Sequential(
conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(2048, 128, kernel_size=3, stride=1, padding=1),
conv_bn_relu(128,128,3,padding=1),
)
self.aux_combine = torch.nn.Sequential(
conv_bn_relu(384, 256, 3,padding=2,dilation=2),
conv_bn_relu(256, 128, 3,padding=2,dilation=2),
conv_bn_relu(128, 128, 3,padding=2,dilation=2),
conv_bn_relu(128, 128, 3,padding=4,dilation=4),
torch.nn.Conv2d(128, cls_dim[-1] + 1,1)
# output : n, num_of_lanes+1, h, w
)
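            # Hedged shape trace for the aux branch (288x800 input, ResNet-18/34):
            # x2z (N,128,36,100); x3z (N,128,18,50) upsampled x2; x4z (N,128,9,25)
            # upsampled x4; concat -> 384 channels -> aux_combine
            # -> (N, num_of_lanes+1, 36, 100) segmentation logits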
initialize_weights(self.aux_header2,self.aux_header3,self.aux_header4,self.aux_combine)
self.cls = torch.nn.Sequential(
torch.nn.Linear(1800, 2048),
torch.nn.ReLU(),
torch.nn.Linear(2048, self.total_dim),
)
        # NOTE: these C3/C4/C5 sizes match a ResNet-18/34 backbone
        # (layer2/3/4 channels); ResNet-50 would need 512/1024/2048 here.
        self.FPN = FPN_Decoder(C3_size=128, C4_size=256, C5_size=512)
        # the FPN always outputs feature_size=64 channels, so only the
        # ResNet-18/34 branch of this 1x1 projection matches its output
        self.pool = torch.nn.Conv2d(64, 8, 1) if backbone in ['34','18'] else torch.nn.Conv2d(2048, 8, 1)
        self.pool_2 = torch.nn.AvgPool2d(4, 4)
        # FPN P3 (1/8 scale): 288,800 -> (64, 36, 100)
        # AvgPool2d(4): (64, 36, 100) -> (64, 9, 25); 1x1 conv: -> (8, 9, 25)
        # flatten: 8 * 9 * 25 = 1800 features per image
        # cls output: num_gridding * sample_rows * num_of_lanes, e.g. 37 * 10 * 4
initialize_weights(self.cls)
initialize_weights(self.FPN)
def forward(self, x):
        # input:  (n, c, h, w)
        # output: group_cls of shape (n, *cls_dim), plus aux_seg if use_aux
x2,x3,fea = self.model(x)
if self.use_aux:
x2z = self.aux_header2(x2)
x3z = self.aux_header3(x3)
            x3z = torch.nn.functional.interpolate(x3z, scale_factor=2, mode='bilinear', align_corners=False)
            x4z = self.aux_header4(fea)
            x4z = torch.nn.functional.interpolate(x4z, scale_factor=4, mode='bilinear', align_corners=False)
aux_seg = torch.cat([x2z,x3z,x4z],dim=1)
aux_seg = self.aux_combine(aux_seg)
else:
aux_seg = None
        fea_list = self.FPN(x2, x3, fea)  # the FPN output fea is (64, 36, 100), i.e. 230,400 values; the expected count is 115,200
fea = fea_list[0]
fea = self.pool_2(fea)
        fea = self.pool(fea).view(-1, 1800)  # fea is flattened to (N, 1800), e.g. (8, 1800) for a batch of 8
        group_cls = self.cls(fea).view(-1, *self.cls_dim)  # cls_dim is the classification shape, e.g. (201, 18, 4)
if self.use_aux:
return group_cls, aux_seg
return group_cls
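# Hedged smoke test (assumes a ResNet-18/34 backbone so the FPN channel sizes
# above line up, and the default 288x800 input):
#   net = parsingNet(backbone='18', pretrained=False,
#                    cls_dim=(201, 18, 4), use_aux=False)
#   out = net(torch.randn(2, 3, 288, 800))   # out.shape == (2, 201, 18, 4)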
def initialize_weights(*models):
for model in models:
real_init_weights(model)
def real_init_weights(m):
if isinstance(m, list):
for mini_m in m:
real_init_weights(mini_m)
else:
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, torch.nn.Linear):
m.weight.data.normal_(0.0, std=0.01)
elif isinstance(m, torch.nn.BatchNorm2d):
torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m,torch.nn.Module):
for mini_m in m.children():
real_init_weights(mini_m)
else:
            print('unknown module', m)
import torch
#from model.transformer_encode import *
from model.mateformer_pool import *
from model.mateformer_attention import *
class FPN_Decoder(torch.nn.Module):  # classic FPN structure
def __init__(self, C3_size, C4_size, C5_size, feature_size=64):
super(FPN_Decoder, self).__init__()
# upsample C5 to get P5 from the FPN paper
self.P5_1 = torch.nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P5_upsampled = torch.nn.Upsample(scale_factor=2, mode='nearest')
self.P5_2 = torch.nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
# add P5 elementwise to C4
self.P4_1 = torch.nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P4_upsampled = torch.nn.Upsample(scale_factor=2, mode='nearest')
self.P4_2 = torch.nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
# add P4 elementwise to C3
self.P3_1 = torch.nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P3_upsampled = torch.nn.Upsample(scale_factor=2, mode='nearest')
self.P3_2 = torch.nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
#self.transformer_encode = TransConvEncoderModule(512, [512,128], [128,128], [1,1], [4,4], pos_shape=[8,9,25])
        #self.mateformer_encode = Mateformer(patch_size1=3, strid1=1, padding1=1, in_chans1=512, out_chans=512)  # these parameters are still to be determined
self.mateformer_encode = metaformer_self()
def forward(self, C3, C4, C5):
P5_x = self.mateformer_encode(C5)
P5_x = self.P5_1(P5_x)
P5_upsampled_x = self.P5_upsampled(P5_x)
P5_x = self.P5_2(P5_x)
P4_x = self.P4_1(C4)
P4_x = P5_upsampled_x + P4_x
P4_upsampled_x = self.P4_upsampled(P4_x)
P4_x = self.P4_2(P4_x)
P3_x = self.P3_1(C3)
P3_x = P3_x + P4_upsampled_x
P3_x = self.P3_2(P3_x)
return [P3_x, P4_x, P5_x]
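# Shape walkthrough (hedged, for the 288x800 input used by parsingNet, where
# C3/C4/C5 are the 1/8, 1/16 and 1/32 ResNet-18/34 features):
#   C3 (N,128,36,100), C4 (N,256,18,50), C5 (N,512,9,25)
#   -> P3 (N,64,36,100), P4 (N,64,18,50), P5 (N,64,9,25)
# metaformer_self keeps C5's shape, and every level is projected to
# feature_size=64 channels before the top-down pathway.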
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MetaFormer implementation with hybrid stages
"""
from typing import Sequence
from functools import partial, reduce
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from model.poolformer import PatchEmbed, LayerNormChannel, GroupNorm, Mlp
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head',
**kwargs
}
class AddPositionEmb(nn.Module):
"""Module to add position embedding to input features
"""
def __init__(
self, dim=384, spatial_shape=[14, 14],
):
super().__init__()
if isinstance(spatial_shape, int):
spatial_shape = [spatial_shape]
assert isinstance(spatial_shape, Sequence), \
f'"spatial_shape" must by a sequence or int, ' \
f'get {type(spatial_shape)} instead.'
if len(spatial_shape) == 1:
embed_shape = list(spatial_shape) + [dim]
else:
embed_shape = [dim] + list(spatial_shape)
self.pos_embed = nn.Parameter(torch.zeros(1, *embed_shape))
def forward(self, x):
return x+self.pos_embed
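# Hedged example: with dim=512 and spatial_shape=[9, 25] (the metaformer_self
# configuration below), pos_embed has shape (1, 512, 9, 25) and broadcasts
# over the batch dimension of a (B, 512, 9, 25) input.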
class Pooling(nn.Module):
"""
Implementation of pooling for PoolFormer
--pool_size: pooling size
"""
def __init__(self, pool_size=3, **kwargs):
super().__init__()
self.pool = nn.AvgPool2d(
pool_size, stride=1, padding=pool_size//2, count_include_pad=False)
def forward(self, x):
return self.pool(x) - x
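# Note on `pool(x) - x`: the input is subtracted because the surrounding
# block already has a residual connection, as noted in the PoolFormer paper.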
class Attention(nn.Module):
"""Attention module that can take tensor with [B, N, C] or [B, C, H, W] as input.
Modified from:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
def __init__(self, dim, head_dim=32, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % head_dim == 0, 'dim should be divisible by head_dim'
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
        shape = x.shape
        if len(shape) == 4:
            B, C, H, W = shape
            N = H * W
            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
        else:
            B, N, C = shape  # already token-form input, per the docstring
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# trick here to make q@k.t more stable
attn = (q * self.scale) @ k.transpose(-2, -1)
# attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
if len(shape) == 4:
x = x.transpose(-2, -1).reshape(B, C, H, W)
return x
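# Hedged usage sketch for both accepted layouts:
#   attn = Attention(dim=512, head_dim=32)        # 16 heads
#   y = attn(torch.randn(2, 512, 9, 25))          # (B, C, H, W) in and out
#   y = attn(torch.randn(2, 225, 512))            # (B, N, C) in and out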
class SpatialFc(nn.Module):
"""SpatialFc module that take features with shape of (B,C,*) as input.
"""
def __init__(
self, spatial_shape=[14, 14], **kwargs,
):
super().__init__()
if isinstance(spatial_shape, int):
spatial_shape = [spatial_shape]
assert isinstance(spatial_shape, Sequence), \
f'"spatial_shape" must by a sequence or int, ' \
f'get {type(spatial_shape)} instead.'
N = reduce(lambda x, y: x * y, spatial_shape)
self.fc = nn.Linear(N, N, bias=False)
def forward(self, x):
# input shape like [B, C, H, W]
shape = x.shape
x = torch.flatten(x, start_dim=2) # [B, C, H*W]
x = self.fc(x) # [B, C, H*W]
x = x.reshape(*shape) # [B, C, H, W]
return x
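# Hedged example: with spatial_shape=[7, 7], the mixer is one 49x49 linear map
# over the flattened spatial axis, shared across all channels:
#   fc = SpatialFc(spatial_shape=[7, 7])
#   y = fc(torch.randn(2, 512, 7, 7))   # output shape equals input shape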
class MetaFormerBlock(nn.Module):
"""
Implementation of one MetaFormer block.
--dim: embedding dim
--token_mixer: token mixer module
--mlp_ratio: mlp expansion ratio
--act_layer: activation
--norm_layer: normalization
--drop: dropout rate
--drop path: Stochastic Depth,
refer to https://arxiv.org/abs/1603.09382
--use_layer_scale, --layer_scale_init_value: LayerScale,
refer to https://arxiv.org/abs/2103.17239
"""
def __init__(self, dim,
token_mixer=nn.Identity,
mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=LayerNormChannel,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = norm_layer(dim)
self.token_mixer = token_mixer(dim=dim)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
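# With LayerScale weights s1, s2, the block computes:
#   Y = X + drop_path(s1 * TokenMixer(Norm(X)))
#   Z = Y + drop_path(s2 * MLP(Norm(Y)))
# i.e. the MetaFormer abstraction: the token mixer is pluggable while the
# residual + channel-MLP structure stays fixed.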
def basic_blocks(dim, index, layers, token_mixer=nn.Identity,
mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=LayerNormChannel,
drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
"""
generate PoolFormer blocks for a stage
return: PoolFormer blocks
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (
block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(MetaFormerBlock(
dim, token_mixer=token_mixer, mlp_ratio=mlp_ratio,
act_layer=act_layer, norm_layer=norm_layer,
drop=drop_rate, drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)
return blocks
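# Hedged worked example of the linear stochastic-depth schedule above: with
# layers=[2, 2, 6, 2] (12 blocks total) and drop_path_rate=0.1, the block at
# global index j (0..11) gets drop probability 0.1 * j / 11, so the first
# block keeps 0.0 and the last reaches the full 0.1.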
class MetaFormer(nn.Module):
"""
MetaFormer, the main class of our model
--layers: [x,x,x,x], number of blocks for the 4 stages
--embed_dims, --mlp_ratios: the embedding dims and mlp ratios for the 4 stages
--token_mixers: token mixers of different stages
--norm_layer, --act_layer: define the types of normalization and activation
--num_classes: number of classes for the image classification
--in_patch_size, --in_stride, --in_pad: specify the patch embedding
for the input image
--down_patch_size --down_stride --down_pad:
specify the downsample (patch embed.)
--add_pos_embs: position embedding modules of different stages
"""
def __init__(self, layers, embed_dims=None,
token_mixers=None, mlp_ratios=None,
norm_layer=LayerNormChannel, act_layer=nn.GELU,
num_classes=1000,
in_patch_size=7, in_stride=4, in_pad=2,
downsamples=None, down_patch_size=3, down_stride=2, down_pad=1,
add_pos_embs=None,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5,
**kwargs):
super().__init__()
self.num_classes = num_classes
        # NOTE: the stem is hardcoded for a 512-channel feature-map input with
        # stride 1 (spatial size preserved), overriding in_patch_size, in_stride
        # and in_pad; the pretrained ImageNet variants below used a 7x7/stride-4 RGB stem.
        self.patch_embed = PatchEmbed(
            patch_size=3, stride=1, padding=1,
            in_chans=512, embed_dim=embed_dims[0])
if add_pos_embs is None:
add_pos_embs = [None] * len(layers)
if token_mixers is None:
token_mixers = [nn.Identity] * len(layers)
# set the main block in network
network = []
for i in range(len(layers)):
if add_pos_embs[i] is not None:
network.append(add_pos_embs[i](embed_dims[i]))
stage = basic_blocks(embed_dims[i], i, layers,
token_mixer=token_mixers[i], mlp_ratio=mlp_ratios[i],
act_layer=act_layer, norm_layer=norm_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value)
network.append(stage)
if i >= len(layers) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
# downsampling between two stages
network.append(
PatchEmbed(
patch_size=down_patch_size, stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
)
)
self.network = nn.ModuleList(network)
self.norm = norm_layer(embed_dims[-1])
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
self.apply(self.cls_init_weights)
# init for classification
def cls_init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def get_classifier(self):
return self.head
    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        # derive the feature dim from the final norm layer; the original code
        # referenced a never-defined self.embed_dim
        embed_dim = self.norm.weight.shape[0]
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_embeddings(self, x):
x = self.patch_embed(x)
return x
def forward_tokens(self, x):
for idx, block in enumerate(self.network):
x = block(x)
return x
def forward(self, x):
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
# x = self.norm(x)
# # for image classification
# cls_out = self.head(x.mean([-2, -1]))
return x
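# With the classification head commented out above, forward() returns the last
# feature map. For metaformer_self below (single stage, stride-1 stem, no
# downsampling), a (B, 512, 9, 25) input comes back as (B, 512, 9, 25),
# matching the __main__ check at the end of this file.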
model_urls = {
"metaformer_id_s12": "https://github.com/sail-sg/poolformer/releases/download/v1.0/metaformer_id_s12.pth.tar",
"metaformer_pppa_s12_224": "https://github.com/sail-sg/poolformer/releases/download/v1.0/metaformer_pppa_s12_224.pth.tar",
"metaformer_ppaa_s12_224": "https://github.com/sail-sg/poolformer/releases/download/v1.0/metaformer_ppaa_s12_224.pth.tar",
"metaformer_pppf_s12_224": "https://github.com/sail-sg/poolformer/releases/download/v1.0/metaformer_pppf_s12_224.pth.tar",
"metaformer_ppff_s12_224": "https://github.com/sail-sg/poolformer/releases/download/v1.0/metaformer_ppff_s12_224.pth.tar",
}
@register_model
def metaformer_id_s12(pretrained=False, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
token_mixers = [nn.Identity] * len(layers)
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
norm_layer=GroupNorm,
downsamples=downsamples,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
if pretrained:
url = model_urls['metaformer_id_s12']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def metaformer_pppa_s12_224(pretrained=False, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
add_pos_embs = [None, None, None,
partial(AddPositionEmb, spatial_shape=[7, 7])]
token_mixers = [Pooling, Pooling, Pooling, Attention]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
downsamples=downsamples,
add_pos_embs=add_pos_embs,
**kwargs)
model.default_cfg = _cfg()
if pretrained:
url = model_urls['metaformer_pppa_s12_224']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def metaformer_ppaa_s12_224(pretrained=False, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
add_pos_embs = [None, None,
partial(AddPositionEmb, spatial_shape=[14, 14]), None]
token_mixers = [Pooling, Pooling, Attention, Attention]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
downsamples=downsamples,
add_pos_embs=add_pos_embs,
**kwargs)
model.default_cfg = _cfg()
if pretrained:
url = model_urls['metaformer_ppaa_s12_224']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def metaformer_pppf_s12_224(pretrained=False, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
token_mixers = [Pooling, Pooling, Pooling,
partial(SpatialFc, spatial_shape=[7, 7]),
]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
norm_layer=GroupNorm,
downsamples=downsamples,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
if pretrained:
url = model_urls['metaformer_pppf_s12_224']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def metaformer_ppff_s12_224(pretrained=False, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
token_mixers = [Pooling, Pooling,
partial(SpatialFc, spatial_shape=[14, 14]),
partial(SpatialFc, spatial_shape=[7, 7]),
]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
norm_layer=GroupNorm,
downsamples=downsamples,
**kwargs)
model.default_cfg = _cfg()
if pretrained:
url = model_urls['metaformer_ppff_s12_224']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def metaformer_self(pretrained=False, **kwargs):
layers = [2]
embed_dims = [512]
add_pos_embs = [partial(AddPositionEmb, spatial_shape=[9, 25])]
token_mixers = [Attention]
mlp_ratios = [4]
downsamples = [True]
model = MetaFormer(
layers, embed_dims=embed_dims,
token_mixers=token_mixers,
mlp_ratios=mlp_ratios,
downsamples=downsamples,
add_pos_embs=add_pos_embs,
**kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
model = metaformer_self()
print(model)
x = torch.randn(8, 512, 9, 25)
out = model(x)
print(out.shape)
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PoolFormer implementation
"""
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
try:
from mmseg.models.builder import BACKBONES as seg_BACKBONES
from mmseg.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmseg = True
except ImportError:
print("If for semantic segmentation, please install mmsegmentation first")
has_mmseg = False
try:
from mmdet.models.builder import BACKBONES as det_BACKBONES
from mmdet.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmdet = True
except ImportError:
print("If for detection, please install mmdetection first")
has_mmdet = False
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head',
**kwargs
}
default_cfgs = {
'poolformer_s': _cfg(crop_pct=0.9),
'poolformer_m': _cfg(crop_pct=0.95),
}
class PatchEmbed(nn.Module):
"""
Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
"""
def __init__(self, patch_size=16, stride=16, padding=0,
in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
stride=stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
# In a ViT-style setup, a batch of images (8, 3, 224, 224) becomes (8, 768, 14, 14),
# i.e. the image is split into 14x14 patches of size (16, 16).
class LayerNormChannel(nn.Module):
"""
LayerNorm only for Channel Dimension.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, eps=1e-05):
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x):
        u = x.mean(1, keepdim=True)  # mean over the channel dimension
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
+ self.bias.unsqueeze(-1).unsqueeze(-1)
return x
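# Equivalently: y = weight * (x - mean_C) / sqrt(var_C + eps) + bias, with the
# statistics computed over the channel dimension (dim=1) only, per position.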
class GroupNorm(nn.GroupNorm):
"""
Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Pooling(nn.Module):
"""
Implementation of pooling for PoolFormer
--pool_size: pooling size
"""
def __init__(self, pool_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
def forward(self, x):
return self.pool(x) - x
class Mlp(nn.Module):
"""
Implementation of MLP with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
"""
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
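# Note: the 1x1 convolutions make this equivalent to applying the same
# two-layer MLP independently at every spatial position of a (B, C, H, W)
# tensor, with no reshaping to token form required.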
class PoolFormerBlock(nn.Module):
"""
Implementation of one PoolFormer block.
--dim: embedding dim
--pool_size: pooling size
--mlp_ratio: mlp expansion ratio
--act_layer: activation
--norm_layer: normalization
--drop: dropout rate
--drop path: Stochastic Depth,
refer to https://arxiv.org/abs/1603.09382
--use_layer_scale, --layer_scale_init_value: LayerScale,
refer to https://arxiv.org/abs/2103.17239
"""
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = norm_layer(dim)
self.token_mixer = Pooling(pool_size=pool_size)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
            # this is Y = TokenMixer(Norm(X)) + X from the paper
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.mlp(self.norm2(x)))
            # this is Z = σ(Norm(Y) W1) W2 + Y from the paper
        else:  # unused here: use_layer_scale defaults to True
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def basic_blocks(dim, index, layers,
pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm,
drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
"""
generate PoolFormer blocks for a stage
return: PoolFormer blocks
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (
block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(PoolFormerBlock(
dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
act_layer=act_layer, norm_layer=norm_layer,
drop=drop_rate, drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)
return blocks
class PoolFormer(nn.Module):
"""
PoolFormer, the main class of our model
--layers: [x,x,x,x], number of blocks for the 4 stages
--embed_dims, --mlp_ratios, --pool_size: the embedding dims, mlp ratios and
pooling size for the 4 stages
--downsamples: flags to apply downsampling or not
--norm_layer, --act_layer: define the types of normalization and activation
--num_classes: number of classes for the image classification
--in_patch_size, --in_stride, --in_pad: specify the patch embedding
for the input image
--down_patch_size --down_stride --down_pad:
specify the downsample (patch embed.)
--fork_feat: whether output features of the 4 stages, for dense prediction
--init_cfg, --pretrained:
for mmdetection and mmsegmentation to load pretrained weights
"""
def __init__(self, layers, embed_dims=None,
mlp_ratios=None, downsamples=None,
pool_size=3,
norm_layer=GroupNorm, act_layer=nn.GELU,
num_classes=1000,
in_patch_size=7, in_stride=4, in_pad=2,
down_patch_size=3, down_stride=2, down_pad=1,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5,
fork_feat=False,
init_cfg=None,
pretrained=None,
**kwargs):
super().__init__()
if not fork_feat:
self.num_classes = num_classes
self.fork_feat = fork_feat
self.patch_embed = PatchEmbed(
patch_size=in_patch_size, stride=in_stride, padding=in_pad,
in_chans=3, embed_dim=embed_dims[0])
# set the main block in network
network = []
for i in range(len(layers)):
stage = basic_blocks(embed_dims[i], i, layers,
pool_size=pool_size, mlp_ratio=mlp_ratios[i],
act_layer=act_layer, norm_layer=norm_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value)
network.append(stage)
if i >= len(layers) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
network.append(
PatchEmbed(
patch_size=down_patch_size, stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
if self.fork_feat:
# add a norm layer for each output
self.out_indices = [0, 2, 4, 6]
for i_emb, i_layer in enumerate(self.out_indices):
if i_emb == 0 and os.environ.get('FORK_LAST3', None):
# TODO: more elegant way
"""For RetinaNet, `start_level=1`. The first norm layer will not used.
cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
"""
layer = nn.Identity()
else:
layer = norm_layer(embed_dims[i_emb])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
else:
# Classifier head
self.norm = norm_layer(embed_dims[-1])
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
self.apply(self.cls_init_weights)
self.init_cfg = copy.deepcopy(init_cfg)
# load pre-trained model
if self.fork_feat and (
self.init_cfg is not None or pretrained is not None):
self.init_weights()
# init for classification
def cls_init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
# init for mmdetection or mmsegmentation by loading
# imagenet pre-trained weights
def init_weights(self, pretrained=None):
logger = get_root_logger()
if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
if self.init_cfg is not None:
ckpt_path = self.init_cfg['checkpoint']
elif pretrained is not None:
ckpt_path = pretrained
ckpt = _load_checkpoint(
ckpt_path, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = _state_dict
missing_keys, unexpected_keys = \
self.load_state_dict(state_dict, False)
# show for debug
# print('missing_keys: ', missing_keys)
# print('unexpected_keys: ', unexpected_keys)
def get_classifier(self):
return self.head
    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        # derive the feature dim from the final norm layer; the original code
        # referenced a never-defined self.embed_dim
        embed_dim = self.norm.weight.shape[0]
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_embeddings(self, x):
x = self.patch_embed(x)
return x
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if self.fork_feat and idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
if self.fork_feat:
# output the features of four stages for dense prediction
return outs
# output only the features of last layer for image classification
return x
def forward(self, x):
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
if self.fork_feat:
            # output features of the four stages for dense prediction
return x
x = self.norm(x)
cls_out = self.head(x.mean([-2, -1]))
# for image classification
return cls_out
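# Hedged shape sketch for classification (ImageNet defaults, e.g.
# poolformer_s12 on a (B, 3, 224, 224) batch):
#   patch_embed (7x7, stride 4)      -> (B, 64, 56, 56)
#   stages + 3 downsampling embeds   -> (B, 512, 7, 7)
#   norm, global mean over H/W, head -> (B, 1000)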
model_urls = {
"poolformer_s12": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar",
"poolformer_s24": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar",
"poolformer_s36": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar",
"poolformer_m36": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar",
"poolformer_m48": "https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar",
}
@register_model
def poolformer_s12(pretrained=False, **kwargs):
"""
PoolFormer-S12 model, Params: 12M
--layers: [x,x,x,x], numbers of layers for the four stages
--embed_dims, --mlp_ratios:
embedding dims and mlp ratios for the four stages
--downsamples: flags to apply downsampling or not in four blocks
"""
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = PoolFormer(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
**kwargs)
model.default_cfg = default_cfgs['poolformer_s']
if pretrained:
url = model_urls['poolformer_s12']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def poolformer_s24(pretrained=False, **kwargs):
"""
PoolFormer-S24 model, Params: 21M
"""
layers = [4, 4, 12, 4]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = PoolFormer(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
**kwargs)
model.default_cfg = default_cfgs['poolformer_s']
if pretrained:
url = model_urls['poolformer_s24']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def poolformer_s36(pretrained=False, **kwargs):
"""
PoolFormer-S36 model, Params: 31M
"""
layers = [6, 6, 18, 6]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = PoolFormer(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['poolformer_s']
if pretrained:
url = model_urls['poolformer_s36']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def poolformer_m36(pretrained=False, **kwargs):
"""
PoolFormer-M36 model, Params: 56M
"""
layers = [6, 6, 18, 6]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = PoolFormer(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['poolformer_m']
if pretrained:
url = model_urls['poolformer_m36']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def poolformer_m48(pretrained=False, **kwargs):
"""
PoolFormer-M48 model, Params: 73M
"""
layers = [8, 8, 24, 8]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
model = PoolFormer(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['poolformer_m']
if pretrained:
url = model_urls['poolformer_m48']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
if has_mmseg and has_mmdet:
"""
The following models are for dense prediction based on
mmdetection and mmsegmentation
"""
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s12_feat(PoolFormer):
"""
PoolFormer-S12 model, Params: 12M
"""
def __init__(self, **kwargs):
layers = [2, 2, 6, 2]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
fork_feat=True,
**kwargs)
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s24_feat(PoolFormer):
"""
PoolFormer-S24 model, Params: 21M
"""
def __init__(self, **kwargs):
layers = [4, 4, 12, 4]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
fork_feat=True,
**kwargs)
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_s36_feat(PoolFormer):
"""
PoolFormer-S36 model, Params: 31M
"""
def __init__(self, **kwargs):
layers = [6, 6, 18, 6]
embed_dims = [64, 128, 320, 512]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_m36_feat(PoolFormer):
"""
        PoolFormer-M36 model, Params: 56M
"""
def __init__(self, **kwargs):
layers = [6, 6, 18, 6]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)
@seg_BACKBONES.register_module()
@det_BACKBONES.register_module()
class poolformer_m48_feat(PoolFormer):
"""
PoolFormer-M48 model, Params: 73M
"""
def __init__(self, **kwargs):
layers = [8, 8, 24, 8]
embed_dims = [96, 192, 384, 768]
mlp_ratios = [4, 4, 4, 4]
downsamples = [True, True, True, True]
super().__init__(
layers, embed_dims=embed_dims,
mlp_ratios=mlp_ratios, downsamples=downsamples,
layer_scale_init_value=1e-6,
fork_feat=True,
**kwargs)