The paper is titled "VanillaNet: the Power of Minimalism in Deep Learning", by Hanting Chen, Yunhe Wang, Jianyuan Guo, and Dacheng Tao, from Huawei Noah's Ark Lab and the School of Computer Science at the University of Sydney.
**Abstract**:
The paper proposes a new neural network architecture, VanillaNet, built around a minimalist design philosophy. VanillaNet avoids great depth, shortcut connections, and complex operations such as self-attention, yielding a structure that is simple yet powerful. Every layer is kept compact and straightforward, and the extra nonlinear activation functions are pruned after training to restore the original shallow architecture. By sidestepping this complexity, VanillaNet is well suited to resource-constrained environments. Extensive experiments show that its performance is on par with well-known deep neural networks and vision transformers, demonstrating the power of minimalism in deep learning.
**Introduction**:
Over the past decades, artificial neural networks have made remarkable progress, largely by increasing network complexity to improve performance. AlexNet and ResNet, for example, achieved breakthrough results on image recognition benchmarks. As networks grow more complex, however, so do the challenges of deploying them: the shortcut operations in ResNet consume a large amount of memory traffic, and complex operations such as the shifted-window self-attention in Swin Transformer require sophisticated engineering to implement.
**VanillaNet architecture**:
VanillaNet follows the common design of neural networks: a stem block, a main body, and a fully connected layer for the classification output. Unlike existing deep networks, VanillaNet uses only one layer per stage, producing an extremely simple network. It has no shortcut connections, because the authors found that adding them brings little performance gain. A simplified sketch of this layout is shown below.
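As a rough, hypothetical illustration of this layout (not the authors' code, which is listed in full further below): the channel widths here follow the vanillanet_5 configuration, and a plain ReLU stands in for the series-informed activation.

```python
import torch
import torch.nn as nn

# Simplified sketch of the VanillaNet layout: a 4x downsampling stem,
# a single convolutional layer per stage, and a fully connected head.
# No shortcut connections are used anywhere.
toy_vanillanet = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size=4, stride=4), nn.ReLU(),            # stem
    nn.Conv2d(512, 1024, kernel_size=1), nn.MaxPool2d(2), nn.ReLU(),  # stage 1
    nn.Conv2d(1024, 2048, kernel_size=1), nn.MaxPool2d(2), nn.ReLU(), # stage 2
    nn.Conv2d(2048, 4096, kernel_size=1), nn.MaxPool2d(2), nn.ReLU(), # stage 3
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(4096, 1000),                                            # classifier
)
print(toy_vanillanet(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
```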
**Training VanillaNet**:
To improve VanillaNet's performance, the authors propose a deep training strategy: at the start of training, each convolutional layer is replaced by two convolutional layers with an activation function between them. As training proceeds, this intermediate nonlinearity is gradually annealed toward an identity mapping, so the two convolutional layers can finally be merged into one, keeping the inference cost of the original shallow network. The sketch below illustrates the merge.
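A minimal numerical check of this merge, assuming 1x1 convolutions and ignoring BatchNorm for simplicity (the `switch_to_deploy` methods in the code below perform the same merge and additionally fold BatchNorm into the weights):

```python
import torch
import torch.nn as nn

# Once the activation between two 1x1 convs has become an identity, the pair
# composes into a single 1x1 conv with weight W = W2 @ W1 and bias b2 + W2 @ b1.
conv1 = nn.Conv2d(8, 8, kernel_size=1)
conv2 = nn.Conv2d(8, 16, kernel_size=1)

merged = nn.Conv2d(8, 16, kernel_size=1)
merged.weight.data = torch.einsum('oi,icjk->ocjk',
                                  conv2.weight.data.squeeze(3).squeeze(2),
                                  conv1.weight.data)
merged.bias.data = conv2.bias.data + (
    conv1.bias.data.view(1, -1, 1, 1) * conv2.weight.data).sum(dim=(1, 2, 3))

x = torch.randn(1, 8, 14, 14)
assert torch.allclose(conv2(conv1(x)), merged(x), atol=1e-5)  # identical outputs
```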
**Experiments**:
The authors validate VanillaNet on the ImageNet dataset. Ablation studies show that the proposed components, including the series-informed activation function and the deep training technique, are effective (the series activation is sketched after this paragraph and implemented by the `activation` class in the listing below). The authors also visualize VanillaNet's features to further study how the network learns from images.
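A stripped-down, hypothetical sketch of the series-informed activation (the class name and parameters here are illustrative): a plain ReLU whose output is passed through a learnable depthwise convolution, so each activation also aggregates information from its spatial neighbourhood.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeriesActivationSketch(nn.Module):
    """ReLU followed by a learnable depthwise conv over a (2n+1)x(2n+1) window."""
    def __init__(self, dim, n=3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, 1, 2 * n + 1, 2 * n + 1) * 0.02)
        self.n = n
        self.dim = dim

    def forward(self, x):
        # groups=dim makes the conv depthwise: each channel only mixes its own
        # spatial neighbourhood, keeping the extra cost small.
        return F.conv2d(F.relu(x), self.weight, padding=self.n, groups=self.dim)

print(SeriesActivationSketch(64)(torch.randn(1, 64, 20, 20)).shape)  # (1, 64, 20, 20)
```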
**Conclusion**:
The paper thoroughly studies the feasibility of building high-performance neural networks without complex architectural components such as shortcuts, great depth, and attention layers, embodying a design paradigm that shifts toward simplicity and elegance. Experimental results show that VanillaNet matches well-known deep neural networks and vision transformers on large-scale image classification, highlighting the potential of minimalism in deep learning. The VanillaNet implementation used as the backbone in the YOLO experiment below is listed next.
```python
import torch
import torch.nn as nn
from timm.layers import weight_init
__all__ = ['vanillanet_5', 'vanillanet_6', 'vanillanet_7', 'vanillanet_8', 'vanillanet_9', 'vanillanet_10',
'vanillanet_11', 'vanillanet_12', 'vanillanet_13', 'vanillanet_13_x1_5', 'vanillanet_13_x1_5_ada_pool']
# Series-informed activation: ReLU followed by a learnable depthwise conv;
# the trailing BatchNorm is fused into the conv weights at deploy time.
class activation(nn.ReLU):
def __init__(self, dim, act_num=3, deploy=False):
super(activation, self).__init__()
self.deploy = deploy
self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num * 2 + 1, act_num * 2 + 1))
self.bias = None
self.bn = nn.BatchNorm2d(dim, eps=1e-6)
self.dim = dim
self.act_num = act_num
weight_init.trunc_normal_(self.weight, std=.02)
def forward(self, x):
if self.deploy:
return torch.nn.functional.conv2d(
super(activation, self).forward(x),
self.weight, self.bias, padding=(self.act_num * 2 + 1) // 2, groups=self.dim)
else:
return self.bn(torch.nn.functional.conv2d(
super(activation, self).forward(x),
self.weight, padding=self.act_num, groups=self.dim))
def _fuse_bn_tensor(self, weight, bn):
kernel = weight
running_mean = bn.running_mean
running_var = bn.running_var
gamma = bn.weight
beta = bn.bias
eps = bn.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta + (0 - running_mean) * gamma / std
def switch_to_deploy(self):
if not self.deploy:
kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
self.weight.data = kernel
self.bias = torch.nn.Parameter(torch.zeros(self.dim))
self.bias.data = bias
self.__delattr__('bn')
self.deploy = True
# One VanillaNet stage: two 1x1 convs (merged into a single conv at deploy time),
# an optional downsampling pool, and a series-informed activation.
class Block(nn.Module):
def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=None):
super().__init__()
self.act_learn = 1
self.deploy = deploy
if self.deploy:
self.conv = nn.Conv2d(dim, dim_out, kernel_size=1)
else:
self.conv1 = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1),
nn.BatchNorm2d(dim, eps=1e-6),
)
self.conv2 = nn.Sequential(
nn.Conv2d(dim, dim_out, kernel_size=1),
nn.BatchNorm2d(dim_out, eps=1e-6)
)
if not ada_pool:
self.pool = nn.Identity() if stride == 1 else nn.MaxPool2d(stride)
else:
self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool))
self.act = activation(dim_out, act_num)
def forward(self, x):
if self.deploy:
x = self.conv(x)
else:
x = self.conv1(x)
x = torch.nn.functional.leaky_relu(x, self.act_learn)
x = self.conv2(x)
x = self.pool(x)
x = self.act(x)
return x
def _fuse_bn_tensor(self, conv, bn):
kernel = conv.weight
bias = conv.bias
running_mean = bn.running_mean
running_var = bn.running_var
gamma = bn.weight
beta = bn.bias
eps = bn.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta + (bias - running_mean) * gamma / std
    def switch_to_deploy(self):
        # Fuse BN into each 1x1 conv, then merge conv1 and conv2 into one conv;
        # this is exact once the leaky_relu slope (act_learn) has reached 1,
        # i.e. the extra nonlinearity has become an identity mapping.
if not self.deploy:
kernel, bias = self._fuse_bn_tensor(self.conv1[0], self.conv1[1])
self.conv1[0].weight.data = kernel
self.conv1[0].bias.data = bias
# kernel, bias = self.conv2[0].weight.data, self.conv2[0].bias.data
kernel, bias = self._fuse_bn_tensor(self.conv2[0], self.conv2[1])
self.conv = self.conv2[0]
self.conv.weight.data = torch.matmul(kernel.transpose(1, 3),
self.conv1[0].weight.data.squeeze(3).squeeze(2)).transpose(1, 3)
self.conv.bias.data = bias + (self.conv1[0].bias.data.view(1, -1, 1, 1) * kernel).sum(3).sum(2).sum(1)
self.__delattr__('conv1')
self.__delattr__('conv2')
self.act.switch_to_deploy()
self.deploy = True
# VanillaNet backbone: a 4x4 stem followed by one Block per stage.
# forward() returns the stem output and every stage's feature map.
class VanillaNet(nn.Module):
def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768],
drop_rate=0, act_num=3, strides=[2, 2, 2, 1], deploy=False, ada_pool=None, **kwargs):
super().__init__()
self.deploy = deploy
if self.deploy:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
activation(dims[0], act_num)
)
else:
self.stem1 = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
nn.BatchNorm2d(dims[0], eps=1e-6),
)
self.stem2 = nn.Sequential(
nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1),
nn.BatchNorm2d(dims[0], eps=1e-6),
activation(dims[0], act_num)
)
self.act_learn = 1
self.stages = nn.ModuleList()
for i in range(len(strides)):
if not ada_pool:
stage = Block(dim=dims[i], dim_out=dims[i + 1], act_num=act_num, stride=strides[i], deploy=deploy)
else:
stage = Block(dim=dims[i], dim_out=dims[i + 1], act_num=act_num, stride=strides[i], deploy=deploy,
ada_pool=ada_pool[i])
self.stages.append(stage)
self.depth = len(strides)
self.apply(self._init_weights)
self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
weight_init.trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
def change_act(self, m):
for i in range(self.depth):
self.stages[i].act_learn = m
self.act_learn = m
def forward(self, x):
results = []
if self.deploy:
x = self.stem(x)
else:
x = self.stem1(x)
x = torch.nn.functional.leaky_relu(x, self.act_learn)
x = self.stem2(x)
results.append(x)
for i in range(self.depth):
x = self.stages[i](x)
results.append(x)
return results
def _fuse_bn_tensor(self, conv, bn):
kernel = conv.weight
bias = conv.bias
running_mean = bn.running_mean
running_var = bn.running_var
gamma = bn.weight
beta = bn.bias
eps = bn.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta + (bias - running_mean) * gamma / std
def switch_to_deploy(self):
if not self.deploy:
self.stem2[2].switch_to_deploy()
kernel, bias = self._fuse_bn_tensor(self.stem1[0], self.stem1[1])
self.stem1[0].weight.data = kernel
self.stem1[0].bias.data = bias
kernel, bias = self._fuse_bn_tensor(self.stem2[0], self.stem2[1])
self.stem1[0].weight.data = torch.einsum('oi,icjk->ocjk', kernel.squeeze(3).squeeze(2),
self.stem1[0].weight.data)
self.stem1[0].bias.data = bias + (self.stem1[0].bias.data.view(1, -1, 1, 1) * kernel).sum(3).sum(2).sum(1)
self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[2]])
self.__delattr__('stem1')
self.__delattr__('stem2')
for i in range(self.depth):
self.stages[i].switch_to_deploy()
self.deploy = True
def vanillanet_5(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(dims=[128 * 4, 256 * 4, 512 * 4, 1024 * 4], strides=[2, 2, 2], **kwargs)
return model
def vanillanet_6(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(dims=[128 * 4, 256 * 4, 512 * 4, 1024 * 4, 1024 * 4], strides=[2, 2, 2, 1], **kwargs)
return model
def vanillanet_7(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 1024 * 4, 1024 * 4], strides=[1, 2, 2, 2, 1], **kwargs)
return model
def vanillanet_8(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
strides=[1, 2, 2, 1, 2, 1], **kwargs)
return model
def vanillanet_9(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
strides=[1, 2, 2, 1, 1, 2, 1], **kwargs)
return model
def vanillanet_10(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
strides=[1, 2, 2, 1, 1, 1, 2, 1],
**kwargs)
return model
def vanillanet_11(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
strides=[1, 2, 2, 1, 1, 1, 1, 2, 1],
**kwargs)
return model
def vanillanet_12(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4, 1024 * 4],
strides=[1, 2, 2, 1, 1, 1, 1, 1, 2, 1],
**kwargs)
return model
def vanillanet_13(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 4, 128 * 4, 256 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 512 * 4, 1024 * 4,
1024 * 4],
strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
**kwargs)
return model
def vanillanet_13_x1_5(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 6, 128 * 6, 256 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 1024 * 6,
1024 * 6],
strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
**kwargs)
return model
def vanillanet_13_x1_5_ada_pool(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128 * 6, 128 * 6, 256 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 512 * 6, 1024 * 6,
1024 * 6],
strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
ada_pool=[0, 38, 19, 0, 0, 0, 0, 0, 0, 10, 0],
**kwargs)
    return model
```
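A minimal usage sketch, assuming it is appended to the file above: build a variant, run a forward pass to get the per-stage feature maps, then call `switch_to_deploy()` to fuse BatchNorm and merge the paired convolutions. Since `act_learn` defaults to 1 (identity), the deployed outputs should match the training-time graph up to numerical precision.

```python
if __name__ == '__main__':
    model = vanillanet_5()
    model.eval()

    x = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        feats = model(x)
    print([f.shape for f in feats])  # stem output + one feature map per stage

    # Fuse BN and merge the paired convs for faster inference.
    model.switch_to_deploy()
    with torch.no_grad():
        feats_deploy = model(x)
    for a, b in zip(feats, feats_deploy):
        print(torch.allclose(a, b, atol=1e-5))  # expected: True for every stage
```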
**Run results**:
```
Ultralytics 8.3.0 🚀 Python-3.9.20 torch-2.0.1 CUDA:0 (NVIDIA GeForce RTX 3060 Laptop GPU, 12288MiB)
YOLO11-VanillaNet summary: 205 layers, 29,772,192 parameters, 0 gradients, 150.5 GFLOPs
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 8/8 [00:04<00:00, 1.75it/s]
all 128 929 0.797 0.614 0.695 0.565
person 61 254 0.96 0.568 0.788 0.603
bicycle 3 6 0.893 0.167 0.201 0.179
car 12 46 0.763 0.21 0.256 0.151
motorcycle 4 5 0.73 1 0.995 0.902
airplane 5 6 0.824 0.788 0.913 0.692
bus 5 7 0.821 0.714 0.718 0.647
train 3 3 0.445 1 0.863 0.642
truck 5 12 0.864 0.5 0.522 0.425
boat 2 6 0.586 0.476 0.554 0.379
traffic light 4 14 0.725 0.143 0.267 0.188
stop sign 2 2 1 0 0.595 0.526
bench 5 9 0.409 0.556 0.513 0.448
bird 2 16 0.957 0.938 0.991 0.69
cat 4 4 0.844 0.75 0.774 0.709
dog 9 9 0.987 0.778 0.899 0.718
horse 1 2 0.918 1 0.995 0.921
elephant 4 17 0.827 0.941 0.902 0.757
bear 1 1 0.802 1 0.995 0.995
zebra 2 4 0.685 1 0.995 0.945
giraffe 4 9 0.62 1 0.956 0.801
backpack 4 6 1 0.533 0.669 0.619
umbrella 4 18 0.711 0.944 0.921 0.716
handbag 9 19 1 0.218 0.365 0.295
tie 6 7 0.715 0.571 0.688 0.492
suitcase 2 4 0.861 0.75 0.776 0.68
frisbee 5 5 0.668 0.4 0.631 0.463
skis 1 1 0 0 0 0
snowboard 2 7 0.81 0.857 0.944 0.684
sports ball 6 6 0.943 0.167 0.171 0.137
kite 2 10 1 0.128 0.315 0.083
baseball bat 4 4 1 0 0.264 0.234
baseball glove 4 7 1 0.191 0.289 0.261
skateboard 3 5 0.979 0.4 0.648 0.43
tennis racket 5 7 1 0.365 0.464 0.381
bottle 6 18 0.897 0.222 0.24 0.212
wine glass 5 16 0.668 0.755 0.733 0.423
cup 10 36 0.838 0.431 0.631 0.379
fork 6 6 0.606 0.273 0.471 0.367
knife 7 16 1 0.366 0.605 0.323
spoon 5 22 0.888 0.36 0.542 0.333
bowl 9 28 0.951 0.69 0.743 0.621
banana 1 1 0.839 1 0.995 0.995
sandwich 2 2 0.761 1 0.995 0.92
orange 1 4 0.45 1 0.995 0.805
broccoli 4 11 0.727 0.455 0.514 0.418
carrot 3 24 0.442 0.75 0.703 0.529
hot dog 1 2 0.825 1 0.995 0.921
pizza 5 5 0.89 0.8 0.862 0.783
donut 2 14 1 0.974 0.995 0.811
cake 4 4 1 0.479 0.794 0.721
chair 9 35 0.86 0.703 0.889 0.646
couch 5 6 0.563 0.438 0.788 0.634
potted plant 9 14 0.844 0.929 0.945 0.744
bed 3 3 0.931 1 0.995 0.895
dining table 10 13 0.857 0.615 0.756 0.608
toilet 2 2 0.878 1 0.995 0.895
tv 2 2 0.909 1 0.995 0.895
laptop 2 3 0.525 0.333 0.525 0.451
mouse 2 2 0 0 0 0
remote 5 8 1 0.236 0.407 0.343
cell phone 5 8 0.735 0.125 0.154 0.122
microwave 3 3 0.929 1 0.995 0.764
oven 5 5 0.917 0.8 0.797 0.657
sink 4 6 0.751 0.505 0.65 0.555
refrigerator 5 5 0.607 0.8 0.906 0.631
book 6 29 0.754 0.422 0.576 0.278
clock 8 9 1 0.737 0.874 0.762
vase 2 2 0.762 0.5 0.521 0.414
scissors 1 1 0.86 1 0.995 0.796
teddy bear 6 21 0.948 0.878 0.971 0.809
toothbrush 2 5 0.853 1 0.995 0.849
Speed: 0.8ms preprocess, 22.1ms inference, 0.0ms loss, 1.7ms postprocess per image
Results saved to runs\train\exp_yolo11-VanillaNet_datasets_coco1283
```
Experiment summary: this backbone consumes a lot of GPU memory, but its accuracy is noticeably better than that of FastNet.