SENet Study Notes and Simulation
Preface
SENet [1] was the winning model of the ImageNet 2017 challenge, and the ImageNet competition was discontinued after that edition. Like ResNet before it, SENet introduces a simple architectural trick: the SE block noticeably improves accuracy while adding only a small amount of extra parameters and computation.
SENet stands for Squeeze-and-Excitation Networks. It took first place in ImageNet 2017, reducing the top-5 error rate to 2.251%. The official models and code can be found in the GitHub repository [2].
SE block
The motivation behind SENet is to take the relationships between channels into account, which leads to the Squeeze-and-Excitation (SE) block [1]. Its goal is to improve the quality of the representations produced by a network by explicitly modelling the interdependencies between the channels of its convolutional features. The SE block can also be described as a mechanism that learns to use global information to selectively emphasise informative features and suppress less useful ones. The SE block is shown in Fig. 1.
The SE block can be viewed as a computational unit, denoted $F_{tr}$, that maps an input $\mathbf{X} \in \mathbb{R}^{H' \times W' \times C'}$ to a feature map $\mathbf{U} \in \mathbb{R}^{H \times W \times C}$. In the notation below, $F_{tr}$ is a convolution, $\mathbf{V} = [V_1, V_2, \dots, V_C]$ denotes the learned set of filter kernels, where $V_c$ refers to the parameters of the $c$-th filter, so the output can be written as $\mathbf{U} = [U_1, U_2, \dots, U_C]$, where:

$$U_c = V_c * \mathbf{X} = \sum_{s=1}^{C'} V_c^s * X^s$$

In Equation 1, $*$ denotes convolution, $V_c = [V_c^1, V_c^2, \dots, V_c^{C'}]$, $\mathbf{X} = [X^1, X^2, \dots, X^{C'}]$ and $u_c \in \mathbb{R}^{H \times W}$; $V_c^s$ is the 2D spatial kernel of $V_c$ that acts on the corresponding channel $X^s$ of $\mathbf{X}$.
A few remarks on this formula:
- The bias terms are omitted to simplify the notation.
- Since the convolution sums over all input channels, the channel-wise feature information is entangled with the local spatial correlations captured by the kernels. The SE block is designed to separate these two kinds of information, so that the model can explicitly learn the relationships between channels (see the sketch right after this list).
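As a small sanity check of this formula (my own sketch, not part of the paper or the later simulation code), the snippet below confirms that a standard convolution is exactly this sum of per-input-channel convolutions, so every output channel mixes information from all input channels:
import torch
import torch.nn.functional as F
X = torch.randn(1, 3, 8, 8)           # input with C' = 3 channels
V = torch.randn(4, 3, 3, 3)           # C = 4 filters, one 2D kernel per input channel
U_full = F.conv2d(X, V, padding=1)    # standard convolution
# rebuild each output channel as the sum over input channels: U_c = sum_s V_c^s * X^s
U_sum = sum(F.conv2d(X[:, s:s+1], V[:, s:s+1], padding=1) for s in range(3))
print(torch.allclose(U_full, U_sum, atol=1e-5))   # True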
Squeeze: Global Information Embedding
To tackle the issue of channel dependencies, we first consider the signal of each channel in the output features. Since every learned filter operates with a local receptive field, each unit of the transformed output cannot exploit contextual information outside of that region.

To mitigate this problem, SENet squeezes global spatial information into a channel descriptor. This is achieved by using global average pooling to generate channel-wise statistics. Formally, a statistic $z \in \mathbb{R}^C$ is generated by shrinking $U$ through its spatial dimensions $H \times W$, so that the $c$-th element of $z$ is computed as:

$$z_c = F_{sq}(u_c) = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j)$$
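In code, the squeeze step is simply global average pooling, i.e. a mean over the spatial dimensions; a minimal check of my own, using assumed tensor sizes:
import torch
import torch.nn.functional as F
U = torch.randn(1, 64, 16, 16)                 # (N, C, H, W)
z = F.adaptive_avg_pool2d(U, 1).flatten(1)     # squeeze via global average pooling -> (N, C)
print(torch.allclose(z, U.mean(dim=(2, 3))))   # True: identical to an explicit spatial mean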
Excitation: Adaptive Recalibration
To make use of the information aggregated in the squeeze operation, a second operation follows whose aim is to fully capture channel-wise dependencies. To achieve this, the function must satisfy two criteria:
- it must be flexible; in particular, it must be capable of learning a nonlinear interaction between channels;
- it must learn a non-mutually-exclusive relationship, since we want to allow multiple channels to be emphasised at the same time.
To meet these criteria, a simple gating mechanism with a sigmoid activation is chosen:

$$s = F_{ex}(z, \mathbf{W}) = \sigma(g(z, \mathbf{W})) = \sigma(\mathbf{W}_2 \, \delta(\mathbf{W}_1 z))$$

where $\delta$ denotes the ReLU function, $\mathbf{W}_1 \in \mathbb{R}^{\frac{C}{r} \times C}$ and $\mathbf{W}_2 \in \mathbb{R}^{C \times \frac{C}{r}}$. To reduce model complexity and aid generalisation, the gate is parameterised as a bottleneck of two fully connected layers: the first layer reduces the dimensionality by a reduction ratio $r$ (a hyperparameter) and is followed by a ReLU, and the second layer restores the original dimensionality. Finally, the learned per-channel activations (squashed into the range 0 to 1 by the sigmoid) rescale the original features $U$:

$$\tilde{x}_c = F_{scale}(u_c, s_c) = s_c \cdot u_c$$

where $\widetilde{\mathbf{X}} = [\widetilde{X}_1, \widetilde{X}_2, \dots, \widetilde{X}_C]$ and $F_{scale}(u_c, s_c)$ denotes channel-wise multiplication of the scalar $s_c$ with the feature map $u_c \in \mathbb{R}^{H \times W}$.

The whole operation can be seen as learning a weight for each channel, which makes the model more discriminative with respect to the features of individual channels; it can also be regarded as a form of attention mechanism [3].
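Putting the squeeze, excitation and scale steps together, a minimal SE block following the formulas above might look like the sketch below. It uses two fully connected layers as in the paper; the simulation code later in this post uses an equivalent 1x1-convolution variant, and the class name SimpleSE here is my own:
import torch
from torch import nn
class SimpleSE(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)  # W1: C -> C/r
        self.fc2 = nn.Linear(channels // reduction, channels)  # W2: C/r -> C
    def forward(self, x):                                      # x: (N, C, H, W)
        z = x.mean(dim=(2, 3))                                 # squeeze
        s = torch.sigmoid(self.fc2(torch.relu(self.fc1(z))))   # excitation
        return x * s.view(x.size(0), -1, 1, 1)                 # scale each channel by s_c
print(SimpleSE(64)(torch.randn(2, 64, 8, 8)).shape)            # torch.Size([2, 64, 8, 8])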
Applications of the SE block
The SE block is very flexible and can be plugged directly into existing network architectures such as GoogLeNet and ResNet, as shown in Fig. 2 and Fig. 3.

Likewise, the SE block can be applied to other network structures. The original table from the paper, giving the detailed structures of SE-ResNet-50 and SE-ResNeXt-50, is reproduced in Table 1.

Adding SE blocks increases the number of parameters and the amount of computation. The additional parameters come solely from the two fully connected layers of the gating mechanism and therefore account for only a small fraction of the network capacity. The extra parameter count is given by Equation 5:
$$\frac{2}{r} \sum_{s=1}^{S} N_s \cdot C_s^2$$
where $r$ is the reduction ratio, $S$ is the number of stages (a stage is a collection of blocks operating on feature maps of a common spatial dimension), $C_s$ is the dimension of the output channels of stage $s$, and $N_s$ is the number of repeated blocks in stage $s$.

With $r = 16$, SE-ResNet-50 adds only about 10% more parameters, while the additional computation is below 1%.
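As a rough sanity check of Equation 5 (using the standard ResNet-50 stage widths and block counts as assumed inputs, not a table from this post), the extra parameter count at $r = 16$ comes out at roughly 2.5M, consistent with the ~10% figure above:
r = 16
stages = [(3, 256), (4, 512), (6, 1024), (3, 2048)]    # (N_s, C_s) for the four ResNet-50 stages
extra = (2 / r) * sum(n * c ** 2 for n, c in stages)
print('%.2fM additional parameters' % (extra / 1e6))   # ~2.51M, about 10% of ResNet-50's ~25M parameters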
Performance of SE models
The SE block can easily be incorporated into other networks. To verify its effectiveness, SE blocks were added to several mainstream networks and the results compared on ImageNet, as shown in Table 2:

All of the networks gain some classification accuracy after adding SE blocks. To get a hands-on feel for the SE block, the rest of this post reproduces it in a small simulation to better understand the architecture and its effect.
SE block simulation
The following code is based on the GitHub repository referenced in [4].
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
import time
device = ('cuda' if torch.cuda.is_available() else 'cpu')
device
'cpu'
# Hyperparameters
EPOCHS = 40
BATCH_SIZE = 128
LEARNING_RATE = 1e-1
WEIGHT_DECAY = 1e-4
Loading the data
The data is obtained with torchvision.datasets.
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
Files already downloaded and verified
Files already downloaded and verified
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
Defining the SENet model
# Squeeze and Excitation Block Module
class SEBlock(nn.Module):
def __init__(self, channels, reduction=16):
super(SEBlock, self).__init__()
self.fc = nn.Sequential(
nn.Conv2d(channels, channels // reduction, 1, bias=False),
nn.ReLU(),
            nn.Conv2d(channels // reduction, channels * 2, 1, bias=False),  # outputs 2*C channels: a per-channel scale and a per-channel bias
)
    def forward(self, x):
        w = F.adaptive_avg_pool2d(x, 1)             # Squeeze: global average pooling -> (N, C, 1, 1)
        w = self.fc(w)                              # Excitation applied to the squeezed descriptor
        w, b = w.split(w.data.size(1) // 2, dim=1)  # split the 2*C outputs into a scale and a bias
        w = torch.sigmoid(w)
        return x * w + b                            # Scale the input features and add the bias
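A quick shape check (my own snippet, not part of the referenced code) confirms that the block returns a tensor of the same shape as its input:
x_check = torch.randn(4, 128, 32, 32)
print(SEBlock(128)(x_check).shape)   # expected: torch.Size([4, 128, 32, 32])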
# Residual Block with SEBlock
class ResBlock(nn.Module):
def __init__(self, channels):
super(ResBlock, self).__init__()
self.conv_lower = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1, bias=False),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.conv_upper = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1, bias=False),
nn.BatchNorm2d(channels)
)
self.se_block = SEBlock(channels)
def forward(self, x):
path = self.conv_lower(x)
path = self.conv_upper(path)
path = self.se_block(path)
path = x + path
return F.relu(path)
# Network Module
class Network(nn.Module):
def __init__(self, in_channel, filters, blocks, num_classes):
super(Network, self).__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_channel, filters, 3, padding=1, bias=False),
nn.BatchNorm2d(filters),
nn.ReLU()
)
self.res_blocks = nn.Sequential(*[ResBlock(filters) for _ in range(blocks - 1)])
self.out_conv = nn.Sequential(
nn.Conv2d(filters, 128, 1, padding=0, bias=False),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.fc = nn.Linear(128, num_classes)
def forward(self, x):
x = self.conv_block(x)
x = self.res_blocks(x)
x = self.out_conv(x)
x = F.adaptive_avg_pool2d(x, 1)
x = x.view(x.data.size(0), -1)
x = self.fc(x)
return F.log_softmax(x, dim=1)
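Before training, a quick forward pass and parameter count (again my own check, with an arbitrary input batch) can confirm that the network handles CIFAR10-sized inputs:
net_check = Network(3, 128, 10, 10)
print(net_check(torch.randn(2, 3, 32, 32)).shape)       # torch.Size([2, 10])
print(sum(p.numel() for p in net_check.parameters()))   # total number of trainable parameters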
Training the model
net = Network(3, 128, 10, 10).to(device)
ACE = nn.CrossEntropyLoss().to(device)
opt = optim.SGD(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, momentum=.9, nesterov=True)
for epoch in range(1, EPOCHS + 1):
print('[Epoch %d]' % epoch)
train_loss = 0
train_correct, train_total = 0, 0
start_point = time.time()
for inputs, labels in train_loader:
inputs, labels = Variable(inputs).to(device),Variable(labels).to(device)
opt.zero_grad()
preds = net(inputs)
loss = ACE(preds, labels)
loss.backward()
opt.step()
train_loss += loss.item()
train_correct += (preds.argmax(dim=1) == labels).sum().item()
train_total += len(preds)
print('train-acc : %.4f%% train-loss : %.5f' % (100 * train_correct / train_total, train_loss / len(train_loader)))
print('elapsed time: %ds' % (time.time() - start_point))
test_loss = 0
test_correct, test_total = 0, 0
for inputs, labels in test_loader:
with torch.no_grad():
inputs, labels = Variable(inputs).to(device), Variable(labels).to(device)
preds = net(inputs)
test_loss += ACE(preds, labels).item()
test_correct += (preds.argmax(dim=1) == labels).sum().item()
test_total += len(preds)
print('test-acc : %.4f%% test-loss : %.5f' % (100 * test_correct / test_total, test_loss / len(test_loader)))
torch.save(net.state_dict(), './data/checkpoint/checkpoint-%04d.bin' % epoch)
[Epoch 1]
train-acc : 62.9240% train-loss : 1.02725
elapsed time: 167s
test-acc : 59.9800% test-loss : 1.13711
[Epoch 2]
train-acc : 69.3160% train-loss : 0.85710
elapsed time: 170s
test-acc : 67.6300% test-loss : 0.92139
[Epoch 3]
train-acc : 73.9000% train-loss : 0.74356
elapsed time: 171s
test-acc : 70.7700% test-loss : 0.84002
[Epoch 4]
train-acc : 77.2340% train-loss : 0.65098
elapsed time: 171s
test-acc : 74.3400% test-loss : 0.75001
[Epoch 5]
train-acc : 79.7560% train-loss : 0.58424
elapsed time: 171s
test-acc : 74.8000% test-loss : 0.71813
[Epoch 6]
train-acc : 81.8820% train-loss : 0.52713
elapsed time: 171s
test-acc : 77.7400% test-loss : 0.66449
[Epoch 7]
train-acc : 83.0260% train-loss : 0.49098
elapsed time: 171s
test-acc : 79.3000% test-loss : 0.60599
[Epoch 8]
train-acc : 84.2880% train-loss : 0.45633
elapsed time: 171s
test-acc : 78.0500% test-loss : 0.64819
[Epoch 9]
train-acc : 85.2660% train-loss : 0.43147
elapsed time: 171s
test-acc : 80.7400% test-loss : 0.57734
[Epoch 10]
train-acc : 86.2080% train-loss : 0.39924
elapsed time: 171s
test-acc : 81.9000% test-loss : 0.53836
[Epoch 11]
train-acc : 86.9320% train-loss : 0.38040
elapsed time: 171s
test-acc : 82.7100% test-loss : 0.51160
[Epoch 12]
train-acc : 87.4740% train-loss : 0.36286
elapsed time: 170s
test-acc : 81.8500% test-loss : 0.54868
[Epoch 13]
train-acc : 88.1580% train-loss : 0.34673
elapsed time: 171s
test-acc : 83.0700% test-loss : 0.49779
[Epoch 14]
train-acc : 88.9260% train-loss : 0.31996
elapsed time: 171s
test-acc : 83.8900% test-loss : 0.48193
[Epoch 15]
train-acc : 89.1380% train-loss : 0.31583
elapsed time: 171s
test-acc : 83.9900% test-loss : 0.49245
[Epoch 16]
train-acc : 89.5460% train-loss : 0.30087
elapsed time: 170s
test-acc : 84.0100% test-loss : 0.49648
[Epoch 17]
train-acc : 90.0420% train-loss : 0.29067
elapsed time: 171s
test-acc : 85.2700% test-loss : 0.44473
[Epoch 18]
train-acc : 90.3720% train-loss : 0.28137
elapsed time: 171s
test-acc : 83.8900% test-loss : 0.49883
[Epoch 19]
train-acc : 90.6020% train-loss : 0.26961
elapsed time: 171s
test-acc : 84.4700% test-loss : 0.47203
[Epoch 20]
train-acc : 91.1460% train-loss : 0.25927
elapsed time: 170s
test-acc : 84.4200% test-loss : 0.49412
[Epoch 21]
train-acc : 91.1540% train-loss : 0.25661
elapsed time: 170s
test-acc : 85.3500% test-loss : 0.43626
[Epoch 22]
train-acc : 91.3620% train-loss : 0.24741
elapsed time: 171s
test-acc : 86.2200% test-loss : 0.41310
[Epoch 23]
train-acc : 91.9760% train-loss : 0.23271
elapsed time: 171s
test-acc : 86.5600% test-loss : 0.40795
[Epoch 24]
train-acc : 92.0000% train-loss : 0.23080
elapsed time: 171s
test-acc : 84.8000% test-loss : 0.46834
[Epoch 25]
train-acc : 92.1460% train-loss : 0.22744
elapsed time: 171s
test-acc : 85.4300% test-loss : 0.44402
[Epoch 26]
train-acc : 92.2120% train-loss : 0.22320
elapsed time: 170s
test-acc : 86.3300% test-loss : 0.41405
[Epoch 27]
train-acc : 92.3740% train-loss : 0.21625
elapsed time: 170s
test-acc : 87.3800% test-loss : 0.38440
[Epoch 28]
train-acc : 92.6960% train-loss : 0.21098
elapsed time: 171s
test-acc : 84.9300% test-loss : 0.46326
[Epoch 29]
train-acc : 92.8700% train-loss : 0.20541
elapsed time: 171s
test-acc : 86.5900% test-loss : 0.41840
[Epoch 30]
train-acc : 93.0700% train-loss : 0.20067
elapsed time: 170s
test-acc : 86.8400% test-loss : 0.42302
[Epoch 31]
train-acc : 93.2300% train-loss : 0.19319
elapsed time: 171s
test-acc : 87.1700% test-loss : 0.39542
[Epoch 32]
train-acc : 93.2280% train-loss : 0.19576
elapsed time: 171s
test-acc : 86.6500% test-loss : 0.43697
[Epoch 33]
train-acc : 93.5900% train-loss : 0.18686
elapsed time: 170s
test-acc : 86.8300% test-loss : 0.40863
[Epoch 34]
train-acc : 93.5820% train-loss : 0.18315
elapsed time: 170s
test-acc : 86.8200% test-loss : 0.42321
[Epoch 35]
train-acc : 93.6140% train-loss : 0.18232
elapsed time: 170s
test-acc : 86.1700% test-loss : 0.43491
[Epoch 36]
train-acc : 93.9620% train-loss : 0.17560
elapsed time: 170s
test-acc : 86.9100% test-loss : 0.41068
[Epoch 37]
train-acc : 93.9920% train-loss : 0.17193
elapsed time: 170s
test-acc : 87.0600% test-loss : 0.41822
[Epoch 38]
train-acc : 93.8620% train-loss : 0.17253
elapsed time: 170s
test-acc : 88.0500% test-loss : 0.38560
[Epoch 39]
train-acc : 94.2040% train-loss : 0.16850
elapsed time: 170s
test-acc : 86.7000% test-loss : 0.42949
[Epoch 40]
train-acc : 94.2940% train-loss : 0.16422
elapsed time: 170s
test-acc : 87.2100% test-loss : 0.39914
Loading the saved model
net.load_state_dict(torch.load('data\\checkpoint\\checkpoint-0040.bin', map_location=torch.device('cpu')))
net.eval()
Network(
(conv_block): Sequential(
(0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(res_blocks): Sequential(
(0): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(1): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(2): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(3): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(4): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(5): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(6): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(7): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(8): ResBlock(
(conv_lower): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv_upper): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(se_block): SEBlock(
(fc): Sequential(
(0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
)
(out_conv): Sequential(
(0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(fc): Linear(in_features=128, out_features=10, bias=True)
)
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
sns.set()
for images, labels in test_loader:
pred = torch.argmax(net(images), axis=1)
print('confusion_matrix: \n', confusion_matrix(pred, labels))
print('accuracy_score:', accuracy_score(pred, labels))
print('precision_score:', precision_score(pred, labels, average='micro'))
print('f1-score:', f1_score(pred, labels, average='micro'))
break
confusion_matrix:
[[11 0 0 0 0 0 0 0 1 0]
[ 0 9 0 0 0 0 0 0 0 0]
[ 0 0 10 0 1 2 0 0 0 0]
[ 0 0 1 11 0 0 0 1 0 0]
[ 0 0 0 1 9 0 0 0 0 0]
[ 0 0 0 1 0 7 1 0 0 0]
[ 0 0 0 0 0 0 18 0 0 0]
[ 1 0 0 2 0 0 0 12 0 0]
[ 1 0 0 0 0 0 0 0 16 0]
[ 0 1 0 0 0 0 0 0 0 11]]
accuracy_score: 0.890625
precision_score: 0.890625
f1-score: 0.890625
pred
tensor([3, 8, 8, 8, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 0, 6, 7, 0, 4, 9,
2, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 3, 4, 9, 9, 5, 4, 6, 5, 6, 0, 9, 4, 9,
7, 6, 9, 8, 7, 3, 8, 8, 7, 3, 2, 5, 7, 5, 6, 3, 6, 2, 1, 2, 7, 7, 2, 6,
8, 8, 0, 2, 9, 3, 7, 8, 8, 1, 1, 7, 2, 2, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
8, 3, 1, 2, 8, 0, 8, 3])
conf_mat = confusion_matrix(labels, pred)
conf_mat
array([[11, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[ 0, 9, 0, 0, 0, 0, 0, 0, 0, 1],
[ 0, 0, 10, 1, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 11, 1, 1, 0, 2, 0, 0],
[ 0, 0, 1, 0, 9, 0, 0, 0, 0, 0],
[ 0, 0, 2, 0, 0, 7, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 1, 18, 0, 0, 0],
[ 0, 0, 0, 1, 0, 0, 0, 12, 0, 0],
[ 1, 0, 0, 0, 0, 0, 0, 0, 16, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 11]], dtype=int64)
df = pd.DataFrame(conf_mat, index=test_dataset.classes, columns=test_dataset.classes)
df
|            | airplane | automobile | bird | cat | deer | dog | frog | horse | ship | truck |
|------------|----------|------------|------|-----|------|-----|------|-------|------|-------|
| airplane   | 11 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 |
| automobile | 0 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| bird       | 0 | 0 | 10 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| cat        | 0 | 0 | 0 | 11 | 1 | 1 | 0 | 2 | 0 | 0 |
| deer       | 0 | 0 | 1 | 0 | 9 | 0 | 0 | 0 | 0 | 0 |
| dog        | 0 | 0 | 2 | 0 | 0 | 7 | 0 | 0 | 0 | 0 |
| frog       | 0 | 0 | 0 | 0 | 0 | 1 | 18 | 0 | 0 | 0 |
| horse      | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 12 | 0 | 0 |
| ship       | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 16 | 0 |
| truck      | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 11 |
# Plot the confusion matrix
plt.figure(figsize=(12, 12))
plt.rcParams['font.sans-serif']=['SimHei']
sns.heatmap(df, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix")
plt.ylabel("True Class")
plt.xlabel("Predicted Class")
plt.show()
References
- [1] Hu J, Shen L, Sun G. Squeeze-and-Excitation Networks[J]. arXiv preprint arXiv:1709.01507, 2017.
Other links
Code for this article: madao33/computer-vision-learning
Personal blog: madao33 blog