网络介绍
GoogLeNet是2014年提出的,主要创新为其inception结构,该结构得名于同名电影《盗梦空间》(Inception),取名为GoogLeNet是向LeNet致敬。
inception块采用并行连接的方式,可以提取到不同尺度的信息,意味着特征更多,分类更准确,缓解了串行卷积层特征丢失的问题。
卷积核大小采用1、3、5,主要是为了输出维度一致。设定卷积核移动步长stride=1之后,只要分别设定padding=0、1、2,那么卷积之后便可以得到相同尺寸的特征,然后这些特征就可以直接聚合在一起。1x1卷积起到了降维,减少计算量的作用,并在一定程度上缓解过拟合。
网络结构
GoogLeNet跟VGG一样,在主体卷积部分中使用5个模块,(代码中使用b1,b2…b5来表示)每个模块之间使用步幅为2的3×3最大池化层来减小输出高宽。
Inception块里有4条并行的线路。前3条线路使用窗口大小分别是1×1、3×3和5×5的卷积层来抽取不同空间尺寸下的信息,其中中间2个线路会对输入先做1×1卷积来减少输入通道数,以降低模型复杂度。第四条线路则使用3×3最大池化层,后接1×1卷积层来改变通道数。4条线路都使用了合适的填充来使输入与输出的高和宽一致。最后将每条线路的输出在通道维上连结,并输入接下来的层中去,下面为Inception结构代码:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated channel-wise.

    All branches keep the spatial size of the input, so their outputs can be
    concatenated along the channel dimension.
    """

    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # Branch 1: single 1x1 convolution.
        self.p1_1 = Conv(in_c, c1, kernel_size=1)
        # Branch 2: 1x1 channel reduction, then 3x3 convolution.
        self.p2_1 = Conv(in_c, c2[0], kernel_size=1)
        self.p2_2 = Conv(c2[0], c2[1], kernel_size=3, padding=1)
        # Branch 3: 1x1 channel reduction, then 5x5 convolution.
        self.p3_1 = Conv(in_c, c3[0], kernel_size=1)
        self.p3_2 = Conv(c3[0], c3[1], kernel_size=5, padding=2)
        # Branch 4: 3x3 max-pool, then 1x1 projection convolution.
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = Conv(in_c, c4, kernel_size=1)

    def forward(self, x):
        p1 = self.p1_1(x)
        p2 = self.p2_2(self.p2_1(x))
        p3 = self.p3_2(self.p3_1(x))
        p4 = self.p4_2(self.p4_1(x))
        # dim=1 is the channel axis of an NCHW tensor.
        return torch.cat((p1, p2, p3, p4), dim=1)
网络搭建及测试
此网络模型大小仅有23M,相比VGG13,大小降低了8倍多!表现却比VGG好。带有两个辅助分类器的网络结构,也才40余M。
此网络的训练和预测评估代码和上一篇VGG基本相同,不同的是该网络训练了两个版本(是否有辅助分类器)。在预测和评估时,如果模型训练使用了辅助分类器,需在加载网络时将aux_logits设置为True(默认为False)。
测试结果:
使用带有辅助分类器的模型结构测试,模型对各类别的预测准确率和召回率有一点提升
不带辅助分类器训练的模型:
带有辅助分类器的模型:
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# 包含卷积和激活的常规卷积块
class Conv(nn.Module):
    """Basic convolution block: a 2d convolution followed by an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
    verbatim to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(Conv, self).__init__()
        # Attribute names stay stable so saved state_dicts keep loading.
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Convolve, then clamp negatives to zero (in place).
        return self.relu(self.conv(x))
class Inception(nn.Module):
    """GoogLeNet Inception block with four parallel branches.

    Branches: 1x1 conv | 1x1 -> 3x3 conv | 1x1 -> 5x5 conv | 3x3 max-pool ->
    1x1 conv.  Every branch preserves the input's height and width, so the
    outputs are concatenated along the channel dimension.
    """

    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # Branch 1: plain 1x1 convolution.
        self.p1_1 = Conv(in_c, c1, kernel_size=1)
        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        self.p2_1 = Conv(in_c, c2[0], kernel_size=1)
        self.p2_2 = Conv(c2[0], c2[1], kernel_size=3, padding=1)
        # Branch 3: 1x1 reduction followed by a 5x5 convolution.
        self.p3_1 = Conv(in_c, c3[0], kernel_size=1)
        self.p3_2 = Conv(c3[0], c3[1], kernel_size=5, padding=2)
        # Branch 4: 3x3 max-pool followed by a 1x1 projection.
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = Conv(in_c, c4, kernel_size=1)

    def forward(self, x):
        branches = (
            self.p1_1(x),
            self.p2_2(self.p2_1(x)),
            self.p3_2(self.p3_1(x)),
            self.p4_2(self.p4_1(x)),
        )
        # Join the branch outputs channel-wise (dim=1 of an NCHW tensor).
        return torch.cat(branches, dim=1)
# 辅助分类器
class InceptionAux(nn.Module):
    """Auxiliary classifier head tapped off an intermediate GoogLeNet stage.

    Pools the feature map, projects it to 128 channels with a 1x1 conv, then
    classifies through two fully-connected layers.  ``fc1``'s input size of
    2048 corresponds to a 4x4 spatial map after the pool (128 * 4 * 4) —
    assumes 224x224 network input; confirm for other input sizes.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = Conv(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        out = self.averagePool(x)
        out = self.conv(out)
        out = torch.flatten(out, 1)
        # Dropout only fires while the parent network is in train mode.
        out = F.dropout(out, 0.5, training=self.training)
        out = F.relu(self.fc1(out), inplace=True)
        out = F.dropout(out, 0.5, training=self.training)
        return self.fc2(out)
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) with optional auxiliary classifiers.

    Args:
        num_classes: size of the final classification layer.
        aux_logits: if True, two auxiliary heads are built; in train mode
            ``forward`` then returns ``(main, aux1, aux2)``.
        init_weights: if True, Xavier-initialise all conv/linear layers.
    """

    def __init__(self, num_classes=5, aux_logits=False, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        # Stage 1: 7x7 stem convolution plus max-pool.
        self.b1 = nn.Sequential(
            Conv(3, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        )
        # Stage 2: 1x1 then 3x3 convolution, plus max-pool.
        self.b2 = nn.Sequential(
            Conv(64, 64, kernel_size=1),
            Conv(64, 192, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        )
        # Stage 3: first pair of Inception blocks, plus max-pool.
        self.b3 = nn.Sequential(
            Inception(192, 64, (96, 128), (16, 32), 32),
            Inception(256, 128, (128, 192), (32, 96), 64),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        )
        # Stage 4 is kept as individual attributes because the auxiliary
        # heads tap its intermediate outputs (after 4a and after 4d).
        self.inception4a = Inception(480, 192, (96, 208), (16, 48), 64)
        self.inception4b = Inception(512, 160, (112, 224), (24, 64), 64)
        self.inception4c = Inception(512, 128, (128, 256), (24, 64), 64)
        self.inception4d = Inception(512, 112, (144, 288), (32, 64), 64)
        self.inception4e = Inception(528, 256, (160, 320), (32, 128), 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        # Convenience wrapper over stage 4 for the no-aux forward path;
        # it reuses the modules above, so no parameters are duplicated.
        self.b4 = nn.Sequential(
            self.inception4a, self.inception4b, self.inception4c, self.inception4d, self.inception4e,
            self.maxpool4
        )
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        # Stage 5: final Inception pair.
        self.b5 = nn.Sequential(
            Inception(832, 256, (160, 320), (32, 128), 128),
            Inception(832, 384, (192, 384), (48, 128), 128),
        )
        # Adaptive pooling forces a 1x1 spatial output regardless of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.b3(self.b2(self.b1(x)))
        # Auxiliary outputs are produced only while training with aux heads.
        use_aux = self.aux_logits and self.training
        if use_aux:
            x = self.inception4a(x)
            aux1 = self.aux1(x)
            x = self.inception4d(self.inception4c(self.inception4b(x)))
            aux2 = self.aux2(x)
            x = self.maxpool4(self.inception4e(x))
        else:
            x = self.b4(x)
        x = self.avgpool(self.b5(x))
        x = self.fc(self.dropout(torch.flatten(x, 1)))
        if use_aux:
            return x, aux1, aux2
        return x

    def _initialize_weights(self):
        # Xavier-uniform weights and zero biases for every conv/linear layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
evaluation.py
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import argparse
import matplotlib.pyplot as plt
import numpy as np
from model import GoogLeNet
def get_opt():
    """Parse command-line options: data directory, weights path, and device."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str,
                        default=r'D:\Codes\DLNet\datas\flower_data\val',
                        help='test data source')
    parser.add_argument('--weights', type=str, default='google.pt')
    parser.add_argument('--device', type=str, default='cuda:0')
    return parser.parse_args()
def evaluate_model(model, dataloader, device):
    """Run inference over ``dataloader`` and compute per-class metrics.

    Returns a tuple ``(precision, recall, f1, cm)`` where the first three are
    per-class arrays (``average=None``) and ``cm`` is the confusion matrix.
    """
    model.eval()
    y_true, y_pred = [], []
    print('start infer...')
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for img, labels in dataloader:
            outputs = model(img.to(device))
            labels = labels.to(device)
            # Index of the max logit is the predicted class.
            predicted = torch.max(outputs, 1)[1]
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)
    cm = confusion_matrix(y_true, y_pred)
    print('infer end, start plot...')
    return precision, recall, f1, cm
def generate_infos(precision, recall, f1, cm):
    """Write per-class metrics to a text file and save a confusion-matrix plot.

    Args:
        precision, recall, f1: per-class metric arrays (one entry per class).
        cm: confusion matrix of raw counts, rows = true labels.

    Side effects: writes 'evaluation_results.txt' and 'confusion_matrix.jpg'
    to the working directory.
    """
    # Dataset class names, in the label order produced by ImageFolder.
    classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    # Save the metric table; names longer than 6 chars get one tab, shorter
    # names two, so columns line up when tabs render at 8 columns.
    with open('evaluation_results.txt', 'w') as f:
        f.write("Class\t\tPrecision\t\tRecall\t\tF1 Score\n")
        for i in range(len(classes)):
            if len(classes[i]) > 6:
                f.write("{}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\n".format(classes[i], precision[i], recall[i], f1[i]))
            else:
                f.write("{}\t\t{:.4f}\t\t{:.4f}\t\t{:.4f}\n".format(classes[i], precision[i], recall[i], f1[i]))
    # Plot the confusion matrix: colours encode raw counts, the cell text
    # shows row-normalised ratios.
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()  # colour scale for the raw counts
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    # Cells darker than half the max count get white text for contrast.
    thresh = cm.max() / 2.
    for i in range(len(classes)):
        row_total = cm[i, :].sum()
        for j in range(len(classes)):
            # Guard against a class absent from the test set: a zero row sum
            # would otherwise divide by zero and print 'nan' in the cell.
            r = cm[i, j] / row_total if row_total else 0.0
            plt.text(j, i, '{:.2f}'.format(r),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.savefig('confusion_matrix.jpg')
def main():
    """Evaluate the trained GoogLeNet on the validation set and save metrics."""
    opt = get_opt()
    # Evaluation must be deterministic: use a fixed resize + centre crop.
    # The original RandomResizedCrop is a training-time augmentation and
    # makes the reported metrics change from run to run.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Build the test loader; no shuffling, so results are reproducible.
    test_dataset = ImageFolder(opt.data, transform=transform)
    test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)
    # Load the model; map_location lets CUDA-saved weights load on any device.
    model = GoogLeNet(num_classes=5, aux_logits=True).to(opt.device)
    model.load_state_dict(torch.load(opt.weights, map_location=opt.device))
    # Compute the evaluation results and persist them to disk.
    precision, recall, f1, cm = evaluate_model(model, test_dataloader, opt.device)
    generate_infos(precision, recall, f1, cm)


if __name__ == '__main__':
    main()