文献标题为《VanillaNet: the Power of Minimalism in Deep Learning》,作者为 Hanting Chen、Yunhe Wang、Jianyuan Guo 和 Dacheng Tao,来自华为诺亚方舟实验室和悉尼大学计算机科学学院。
过去几十年间,人工神经网络取得了显著进展,而性能的提升在很大程度上依赖于不断增加的网络复杂度。例如,AlexNet 和 ResNet 在图像识别基准测试中取得了突破性的成绩。然而,网络复杂度越高,部署难度也随之增大。例如,ResNet 中的捷径连接(shortcut)操作会产生大量片外内存访问流量,而 Swin Transformer 中的移位窗口自注意力(shifted window self-attention)等复杂操作则需要复杂的工程实现。
import torch
import torch.nn as nn
from timm.layers import weight_init
# Public factory functions exported by ``from <module> import *``.
__all__ = [
    'vanillanet_5',
    'vanillanet_6',
    'vanillanet_7',
    'vanillanet_8',
    'vanillanet_9',
    'vanillanet_10',
    'vanillanet_11',
    'vanillanet_12',
    'vanillanet_13',
    'vanillanet_13_x1_5',
    'vanillanet_13_x1_5_ada_pool',
]
class activation(nn.ReLU):
    """ReLU followed by a learnable per-channel (depthwise) convolution.

    The depthwise kernel has shape ``(dim, 1, 2 * act_num + 1, 2 * act_num + 1)``
    and is applied with ``groups=dim`` and padding ``act_num``, so the spatial
    size of the input is preserved.

    Training form (``deploy=False``): bias-free depthwise conv followed by
    ``BatchNorm2d``. Deploy form (``deploy=True``): the BatchNorm is folded
    into the conv weight plus a per-channel bias. ``switch_to_deploy``
    converts the training form into the deploy form. A module built with
    ``deploy=True`` has the same parameters (``weight``, ``bias``) as a
    converted one, so fused checkpoints can be loaded into it.

    Args:
        dim: Number of input (and output) channels.
        act_num: Half-width of the depthwise kernel.
        deploy: Build the fused deploy form directly.
    """

    def __init__(self, dim, act_num=3, deploy=False):
        super(activation, self).__init__()
        self.deploy = deploy
        self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num * 2 + 1, act_num * 2 + 1))
        if deploy:
            # Same parameter layout that switch_to_deploy() produces.
            self.bias = torch.nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)
        self.dim = dim
        self.act_num = act_num
        # torch.nn.init.trunc_normal_ has the same signature and sampling
        # procedure as timm's weight_init.trunc_normal_.
        nn.init.trunc_normal_(self.weight, std=.02)

    def forward(self, x):
        """Apply ReLU, then the depthwise conv (and BatchNorm in training form)."""
        x = super(activation, self).forward(x)  # plain ReLU from nn.ReLU
        # padding == act_num == (2 * act_num + 1) // 2 keeps H and W unchanged.
        # In training form self.bias is None, so the conv is bias-free.
        x = nn.functional.conv2d(x, self.weight, self.bias, padding=self.act_num, groups=self.dim)
        return x if self.deploy else self.bn(x)

    def _fuse_bn_tensor(self, weight, bn):
        """Fold ``bn`` into the bias-free conv ``weight``; return ``(kernel, bias)``.

        With ``t = gamma / sqrt(running_var + eps)`` per output channel:
        ``kernel = weight * t`` and ``bias = beta - running_mean * t``.
        """
        std = (bn.running_var + bn.eps).sqrt()
        t = bn.weight / std
        return weight * t.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * t

    @torch.no_grad()
    def switch_to_deploy(self):
        """Fold the BatchNorm into ``weight``/``bias`` and switch to deploy form.

        Uses the BatchNorm running statistics, so the result matches the
        training form evaluated in eval mode. No-op if already in deploy form.
        """
        if self.deploy:
            return
        kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
        self.weight.data = kernel
        self.bias = torch.nn.Parameter(bias)
        # forward() no longer uses the BatchNorm; drop it so it does not linger
        # in parameters()/state_dict().
        del self.bn
        self.deploy = True
class Block(nn.Module):
    """One VanillaNet stage.

    Training form: 1x1 conv + BN -> leaky ReLU (negative slope ``act_learn``)
    -> 1x1 conv + BN. Deploy form: a single 1x1 conv. Both forms are followed
    by pooling and an ``activation`` layer.

    Pooling is the identity when ``stride == 1``. Otherwise it is
    ``MaxPool2d(stride)``, or ``AdaptiveMaxPool2d((ada_pool, ada_pool))`` when
    ``ada_pool`` is truthy.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        act_num: Half-width of the depthwise kernel in ``activation``.
        stride: Pooling stride (1 disables pooling).
        deploy: Build the fused deploy form directly.
        ada_pool: Optional adaptive max-pool output size.
    """

    def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=None):
        super().__init__()
        # Negative slope of the leaky ReLU between conv1 and conv2 (training
        # form only). A slope of 1 makes it the identity, which is what
        # switch_to_deploy's conv1/conv2 merge assumes.
        self.act_learn = 1
        self.deploy = deploy
        if self.deploy:
            self.conv = nn.Conv2d(dim, dim_out, kernel_size=1)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.BatchNorm2d(dim, eps=1e-6),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(dim, dim_out, kernel_size=1),
                nn.BatchNorm2d(dim_out, eps=1e-6),
            )

        if stride == 1:
            self.pool = nn.Identity()
        elif not ada_pool:
            self.pool = nn.MaxPool2d(stride)
        else:
            self.pool = nn.AdaptiveMaxPool2d((ada_pool, ada_pool))

        # Keep the activation in the same form as the convs, so a deploy=True
        # block matches what switch_to_deploy() produces.
        self.act = activation(dim_out, act_num, deploy=deploy)

    def forward(self, x):
        """Run the conv path for the current form, then pool and activate."""
        if self.deploy:
            x = self.conv(x)
        else:
            x = self.conv1(x)
            x = nn.functional.leaky_relu(x, self.act_learn)
            x = self.conv2(x)
        return self.act(self.pool(x))

    def _fuse_bn_tensor(self, conv, bn):
        """Fold ``bn`` into ``conv``; return the fused ``(kernel, bias)``.

        With ``t = gamma / sqrt(running_var + eps)`` per output channel:
        ``kernel = conv.weight * t`` and ``bias = beta + (conv.bias - running_mean) * t``.
        """
        std = (bn.running_var + bn.eps).sqrt()
        t = bn.weight / std
        return conv.weight * t.reshape(-1, 1, 1, 1), bn.bias + (conv.bias - bn.running_mean) * t

    @torch.no_grad()
    def switch_to_deploy(self):
        """Merge conv1+BN, conv2+BN and the activation's BN into the deploy form.

        conv1 and conv2 are multiplied together directly, dropping the leaky
        ReLU between them. The result therefore matches the training form only
        when ``act_learn == 1``. BN running statistics are used, i.e. eval-mode
        behavior. No-op if already in deploy form.
        """
        if self.deploy:
            return
        k1, b1 = self._fuse_bn_tensor(self.conv1[0], self.conv1[1])  # (dim, dim, 1, 1), (dim,)
        k2, b2 = self._fuse_bn_tensor(self.conv2[0], self.conv2[1])  # (dim_out, dim, 1, 1), (dim_out,)
        w2 = k2.flatten(1)  # (dim_out, dim)
        self.conv = self.conv2[0]
        # conv2(conv1(x)) with 1x1 kernels == a single 1x1 conv whose weight is
        # W2 @ W1 and whose bias is b2 + W2 @ b1.
        self.conv.weight.data = (w2 @ k1.flatten(1)).reshape_as(k2)
        self.conv.bias.data = b2 + w2 @ b1
        # Drop the pre-fusion modules so they (and the alias conv2.0 == conv)
        # do not linger in parameters()/state_dict().
        del self.conv1
        del self.conv2
        self.act.switch_to_deploy()
        self.deploy = True
class VanillaNet(nn.Module):
    """VanillaNet backbone that returns a list of intermediate feature maps.

    Stem (training form): 4x4/stride-4 conv + BN -> leaky ReLU (slope
    ``act_learn``) -> 1x1 conv + BN -> ``activation``. The stem is followed by
    ``len(strides)`` :class:`Block` stages; stage ``i`` maps ``dims[i]`` to
    ``dims[i + 1]`` channels and pools with ``strides[i]``.

    Args:
        in_chans: Channels of the input image.
        num_classes: Unused by this backbone; kept for interface compatibility.
        dims: Channel widths; at least ``len(strides) + 1`` entries are needed.
            The default ``dims``/``strides`` pair does not satisfy this, so
            explicit values must be passed (the ``vanillanet_*`` factories do).
        drop_rate: Unused; kept for interface compatibility.
        act_num: Half-width of the depthwise kernel in every ``activation``.
        strides: Per-stage pooling stride (1 disables pooling).
        deploy: Build the fused (BatchNorm-free) form directly.
        ada_pool: Optional sequence of per-stage adaptive max-pool sizes. A
            falsy entry falls back to regular pooling for that stage.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``len(dims) < len(strides) + 1``.

    Attributes:
        width_list: Channel count of each feature map returned by ``forward``.
    """

    def __init__(self, in_chans=3, num_classes=1000, dims=(96, 192, 384, 768),
                 drop_rate=0, act_num=3, strides=(2, 2, 2, 1), deploy=False, ada_pool=None, **kwargs):
        super().__init__()
        if len(dims) < len(strides) + 1:
            raise ValueError(
                f"VanillaNet needs len(dims) >= len(strides) + 1, "
                f"got {len(dims)} dims for {len(strides)} strides")
        self.deploy = deploy
        if self.deploy:
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                # Fused stem: its activation is in deploy form as well, matching
                # what switch_to_deploy() produces.
                activation(dims[0], act_num, deploy=deploy),
            )
        else:
            self.stem1 = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                nn.BatchNorm2d(dims[0], eps=1e-6),
            )
            self.stem2 = nn.Sequential(
                nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1),
                nn.BatchNorm2d(dims[0], eps=1e-6),
                activation(dims[0], act_num),
            )
        # Negative slope of the stem's leaky ReLU (see change_act).
        self.act_learn = 1

        self.stages = nn.ModuleList()
        for i, stride in enumerate(strides):
            self.stages.append(Block(
                dim=dims[i], dim_out=dims[i + 1], act_num=act_num, stride=stride,
                deploy=deploy, ada_pool=ada_pool[i] if ada_pool else None))
        self.depth = len(strides)

        self.width_list = self._probe_width_list(in_chans)

    def _probe_width_list(self, in_chans, size=640):
        """Return the channel count of every feature map ``forward`` produces.

        Runs one random ``size`` x ``size`` image with ``in_chans`` channels
        through the network in eval mode under ``torch.no_grad()``. This builds
        no autograd graph and leaves the BatchNorm running statistics untouched.
        The previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                feats = self.forward(torch.randn(1, in_chans, size, size))
        finally:
            self.train(was_training)
        return [f.size(1) for f in feats]

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) init for Conv2d/Linear weights; zero bias.

        Not applied by ``__init__``; use ``model.apply(model._init_weights)``.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def change_act(self, m):
        """Set the leaky-ReLU negative slope in the stem and in every stage.

        ``switch_to_deploy`` is exact only once this slope is 1 (identity).
        """
        for stage in self.stages:
            stage.act_learn = m
        self.act_learn = m

    def forward(self, x):
        """Return the stem output followed by the output of every stage."""
        if self.deploy:
            x = self.stem(x)
        else:
            x = self.stem1(x)
            x = nn.functional.leaky_relu(x, self.act_learn)
            x = self.stem2(x)
        # NOTE(review): returns the stem output plus every stage output, in
        # order. Confirm this list matches the indices the consuming detector
        # config expects.
        results = [x]
        for stage in self.stages:
            x = stage(x)
            results.append(x)
        return results

    def _fuse_bn_tensor(self, conv, bn):
        """Fold ``bn`` into ``conv``; return the fused ``(kernel, bias)``.

        With ``t = gamma / sqrt(running_var + eps)`` per output channel:
        ``kernel = conv.weight * t`` and ``bias = beta + (conv.bias - running_mean) * t``.
        """
        std = (bn.running_var + bn.eps).sqrt()
        t = bn.weight / std
        return conv.weight * t.reshape(-1, 1, 1, 1), bn.bias + (conv.bias - bn.running_mean) * t

    @torch.no_grad()
    def switch_to_deploy(self):
        """Convert the stem and every stage to the fused deploy form.

        The stem's 4x4 conv and 1x1 conv are merged directly, dropping the leaky
        ReLU between them. The result is therefore exact only when
        ``act_learn == 1``. BN running statistics are used. No-op if already
        in deploy form.
        """
        if self.deploy:
            return
        self.stem2[2].switch_to_deploy()
        k1, b1 = self._fuse_bn_tensor(self.stem1[0], self.stem1[1])  # (d, in_chans, 4, 4), (d,)
        k2, b2 = self._fuse_bn_tensor(self.stem2[0], self.stem2[1])  # (d, d, 1, 1), (d,)
        w2 = k2.flatten(1)  # (d, d)
        conv = self.stem1[0]
        # 1x1 conv after the 4x4 conv == one 4x4 conv with weight W2 . W1 and
        # bias b2 + W2 @ b1.
        conv.weight.data = torch.einsum('oi,icjk->ocjk', w2, k1)
        conv.bias.data = b2 + w2 @ b1
        self.stem = nn.Sequential(conv, self.stem2[2])
        del self.stem1
        del self.stem2

        for stage in self.stages:
            stage.switch_to_deploy()
        self.deploy = True
def vanillanet_5(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_5`` configuration: stem + 3 stride-2 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(dims=[128 * 4, 256 * 4, 512 * 4, 1024 * 4], strides=[2, 2, 2], **kwargs)
    return model
def vanillanet_6(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_6`` configuration: stem + 4 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(dims=[128 * 4, 256 * 4, 512 * 4, 1024 * 4, 1024 * 4], strides=[2, 2, 2, 1], **kwargs)
    return model
def vanillanet_7(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_7`` configuration: stem + 5 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 1024 * 4, 1024 * 4], strides=[1, 2, 2, 2, 1], **kwargs)
    return model
def vanillanet_8(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_8`` configuration: stem + 6 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
                       strides=[1, 2, 2, 1, 2, 1], **kwargs)
    return model
def vanillanet_9(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_9`` configuration: stem + 7 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
                       strides=[1, 2, 2, 1, 1, 2, 1], **kwargs)
    return model
def vanillanet_10(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_10`` configuration: stem + 8 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
        strides=[1, 2, 2, 1, 1, 1, 2, 1],
        **kwargs)
    return model
def vanillanet_11(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_11`` configuration: stem + 9 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
        strides=[1, 2, 2, 1, 1, 1, 1, 2, 1],
        **kwargs)
    return model
def vanillanet_12(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_12`` configuration: stem + 10 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
        strides=[1, 2, 2, 1, 1, 1, 1, 1, 2, 1],
        **kwargs)
    return model
def vanillanet_13(pretrained=False, in_22k=False, **kwargs):
    """Build the ``vanillanet_13`` configuration: stem + 11 stages.

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4,
              1024 * 4],
        strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
        **kwargs)
    return model
def vanillanet_13_x1_5(pretrained=False, in_22k=False, **kwargs):
    """Build ``vanillanet_13`` with 1.5x channel widths (``128 * 6`` base).

    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 6, 128 * 6, 256 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 1024 * 6,
              1024 * 6],
        strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
        **kwargs)
    return model
def vanillanet_13_x1_5_ada_pool(pretrained=False, in_22k=False, **kwargs):
    """Build ``vanillanet_13_x1_5`` with adaptive max pooling on the stride-2 stages.

    ``ada_pool`` lists one entry per stage: a non-zero value is the adaptive
    max-pool output size for that stage, and 0 keeps the regular pooling.
    ``pretrained`` and ``in_22k`` are accepted but ignored; no weights are
    loaded here. Remaining keyword arguments are forwarded to ``VanillaNet``.
    """
    model = VanillaNet(
        dims=[128 * 6, 128 * 6, 256 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 1024 * 6,
              1024 * 6],
        strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
        ada_pool=[0, 38, 19, 0, 0, 0, 0, 0, 0, 10, 0],
        **kwargs)
    return model
Ultralytics 8.3.0 🚀 Python-3.9.20 torch-2.0.1 CUDA:0 (NVIDIA GeForce RTX 3060 Laptop GPU, 12288MiB)
YOLO11-VanillaNet summary: 205 layers, 29,772,192 parameters, 0 gradients, 150.5 GFLOPs
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 8/8 [00:04<00:00, 1.75it/s]
all 128 929 0.797 0.614 0.695 0.565
person 61 254 0.96 0.568 0.788 0.603
bicycle 3 6 0.893 0.167 0.201 0.179
car 12 46 0.763 0.21 0.256 0.151
motorcycle 4 5 0.73 1 0.995 0.902
airplane 5 6 0.824 0.788 0.913 0.692
bus 5 7 0.821 0.714 0.718 0.647
train 3 3 0.445 1 0.863 0.642
truck 5 12 0.864 0.5 0.522 0.425
boat 2 6 0.586 0.476 0.554 0.379
traffic light 4 14 0.725 0.143 0.267 0.188
stop sign 2 2 1 0 0.595 0.526
bench 5 9 0.409 0.556 0.513 0.448
bird 2 16 0.957 0.938 0.991 0.69
cat 4 4 0.844 0.75 0.774 0.709
dog 9 9 0.987 0.778 0.899 0.718
horse 1 2 0.918 1 0.995 0.921
elephant 4 17 0.827 0.941 0.902 0.757
bear 1 1 0.802 1 0.995 0.995
zebra 2 4 0.685 1 0.995 0.945
giraffe 4 9 0.62 1 0.956 0.801
backpack 4 6 1 0.533 0.669 0.619
umbrella 4 18 0.711 0.944 0.921 0.716
handbag 9 19 1 0.218 0.365 0.295
tie 6 7 0.715 0.571 0.688 0.492
suitcase 2 4 0.861 0.75 0.776 0.68
frisbee 5 5 0.668 0.4 0.631 0.463
skis 1 1 0 0 0 0
snowboard 2 7 0.81 0.857 0.944 0.684
sports ball 6 6 0.943 0.167 0.171 0.137
kite 2 10 1 0.128 0.315 0.083
baseball bat 4 4 1 0 0.264 0.234
baseball glove 4 7 1 0.191 0.289 0.261
skateboard 3 5 0.979 0.4 0.648 0.43
tennis racket 5 7 1 0.365 0.464 0.381
bottle 6 18 0.897 0.222 0.24 0.212
wine glass 5 16 0.668 0.755 0.733 0.423
cup 10 36 0.838 0.431 0.631 0.379
fork 6 6 0.606 0.273 0.471 0.367
knife 7 16 1 0.366 0.605 0.323
spoon 5 22 0.888 0.36 0.542 0.333
bowl 9 28 0.951 0.69 0.743 0.621
banana 1 1 0.839 1 0.995 0.995
sandwich 2 2 0.761 1 0.995 0.92
orange 1 4 0.45 1 0.995 0.805
broccoli 4 11 0.727 0.455 0.514 0.418
carrot 3 24 0.442 0.75 0.703 0.529
hot dog 1 2 0.825 1 0.995 0.921
pizza 5 5 0.89 0.8 0.862 0.783
donut 2 14 1 0.974 0.995 0.811
cake 4 4 1 0.479 0.794 0.721
chair 9 35 0.86 0.703 0.889 0.646
couch 5 6 0.563 0.438 0.788 0.634
potted plant 9 14 0.844 0.929 0.945 0.744
bed 3 3 0.931 1 0.995 0.895
dining table 10 13 0.857 0.615 0.756 0.608
toilet 2 2 0.878 1 0.995 0.895
tv 2 2 0.909 1 0.995 0.895
laptop 2 3 0.525 0.333 0.525 0.451
mouse 2 2 0 0 0 0
remote 5 8 1 0.236 0.407 0.343
cell phone 5 8 0.735 0.125 0.154 0.122
microwave 3 3 0.929 1 0.995 0.764
oven 5 5 0.917 0.8 0.797 0.657
sink 4 6 0.751 0.505 0.65 0.555
refrigerator 5 5 0.607 0.8 0.906 0.631
book 6 29 0.754 0.422 0.576 0.278
clock 8 9 1 0.737 0.874 0.762
vase 2 2 0.762 0.5 0.521 0.414
scissors 1 1 0.86 1 0.995 0.796
teddy bear 6 21 0.948 0.878 0.971 0.809
toothbrush 2 5 0.853 1 0.995 0.849
Speed: 0.8ms preprocess, 22.1ms inference, 0.0ms loss, 1.7ms postprocess per image
Results saved to runs\train\exp_yolo11-VanillaNet_datasets_coco1283