import torch
import torch.nn as nn
from torchsummary import summary
# Pick the compute device once at import time: GPU when available, else CPU
# (used by the __main__ demo below).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ConvMixerLayer(nn.Module):
    """One ConvMixer block.

    A depthwise convolution (spatial mixing, wrapped in a residual
    connection) followed by a pointwise 1x1 convolution (channel mixing).
    Each convolution is followed by GELU and BatchNorm. Input and output
    shapes are identical: (N, dim, H, W).
    """

    def __init__(self, dim, kernel_size=9):
        super().__init__()
        # Depthwise stage: groups=dim makes the conv act per-channel;
        # padding='same' keeps the spatial size unchanged.
        self.Resnet = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding='same'),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        # Pointwise stage: 1x1 conv mixes information across channels.
        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

    def forward(self, x):
        """Apply the residual depthwise stage, then the channel-mixing stage."""
        residual = x
        mixed = self.Resnet(x) + residual
        return self.Conv_1x1(mixed)
class ConvMixer(nn.Module):
    """ConvMixer classifier.

    Pipeline: patch-embedding conv -> `depth` ConvMixer blocks ->
    global average pooling -> linear classification head.
    """

    def __init__(self, dim=512, depth=5, kernel_size=5, patch_size=7, n_classes=1000):
        super().__init__()
        # Patch embedding: non-overlapping patch_size x patch_size conv
        # (stride == kernel size), so a HxW image becomes H/p x W/p tokens.
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )
        # Stack of `depth` identical mixing blocks.
        self.ConvMixer_blocks = nn.ModuleList(
            ConvMixerLayer(dim=dim, kernel_size=kernel_size) for _ in range(depth)
        )
        # Classification head: pool to 1x1, flatten, project to class logits.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(dim, n_classes),
        )

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to class logits (N, n_classes)."""
        x = self.conv2d1(x)
        for block in self.ConvMixer_blocks:
            x = block(x)
        return self.head(x)
#
# Smoke test: build a small (depth=2) model, print its module structure,
# and show a torchsummary per-layer report for a 224x224 RGB input.
if __name__ == '__main__':
    model = ConvMixer(dim=512, depth=2).to(device)
    print(model)
    summary(model, (3, 224, 224))
# 运行结果 (run output):
"""
ConvMixer(
(conv2d1): Sequential(
(0): Conv2d(3, 512, kernel_size=(7, 7), stride=(7, 7))
(1): GELU()
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(ConvMixer_blocks): ModuleList(
(0): ConvMixerLayer(
(Resnet): Sequential(
(0): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=same, groups=512)
(1): GELU()
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(Conv_1x1): Sequential(
(0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(1): GELU()
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ConvMixerLayer(
(Resnet): Sequential(
(0): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=same, groups=512)
(1): GELU()
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(Conv_1x1): Sequential(
(0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(1): GELU()
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(head): Sequential(
(0): AdaptiveAvgPool2d(output_size=(1, 1))
(1): Flatten(start_dim=1, end_dim=-1)
(2): Linear(in_features=512, out_features=1000, bias=True)
)
)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 512, 32, 32] 75,776
GELU-2 [-1, 512, 32, 32] 0
BatchNorm2d-3 [-1, 512, 32, 32] 1,024
Conv2d-4 [-1, 512, 32, 32] 13,312
GELU-5 [-1, 512, 32, 32] 0
BatchNorm2d-6 [-1, 512, 32, 32] 1,024
Conv2d-7 [-1, 512, 32, 32] 262,656
GELU-8 [-1, 512, 32, 32] 0
BatchNorm2d-9 [-1, 512, 32, 32] 1,024
ConvMixerLayer-10 [-1, 512, 32, 32] 0
Conv2d-11 [-1, 512, 32, 32] 13,312
GELU-12 [-1, 512, 32, 32] 0
BatchNorm2d-13 [-1, 512, 32, 32] 1,024
Conv2d-14 [-1, 512, 32, 32] 262,656
GELU-15 [-1, 512, 32, 32] 0
BatchNorm2d-16 [-1, 512, 32, 32] 1,024
ConvMixerLayer-17 [-1, 512, 32, 32] 0
AdaptiveAvgPool2d-18 [-1, 512, 1, 1] 0
Flatten-19 [-1, 512] 0
Linear-20 [-1, 1000] 513,000
================================================================
Total params: 1,145,832
Trainable params: 1,145,832
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 68.02
Params size (MB): 4.37
Estimated Total Size (MB): 72.96
----------------------------------------------------------------
"""