This article uses the same data and the same train/val split as the first article on AlexNet, and its predict.py and train.py are essentially the same as AlexNet's; the only difference is the model that gets built (model=AlexNet -> model=VGG13). In addition, this article adds a model evaluation script: running it saves per-class precision, recall, and F1 score and generates a confusion matrix, which makes later comparisons with other networks easier.
During training, VGG is noticeably slower than AlexNet and uses more GPU memory at the same batch size, and the trained model is 247 MB, much larger than AlexNet's. On the other hand, on the flower classification dataset VGG achieves an mAP improvement of more than 10% over AlexNet.
1. Environment Information:
Windows 11
torch 2.1.0
CUDA 12.1
Python 3.9
Other Python, CUDA, and torch 1.x versions should work as well.
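If you want to confirm what torch actually sees on your machine, a quick optional check looks like this:
import torch
print(torch.__version__)          # e.g. 2.1.0
print(torch.version.cuda)         # CUDA version torch was built against, e.g. 12.1
print(torch.cuda.is_available())  # True if the GPU can be used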
2. Dataset Download
Before building the network, download the data.
The network is trained on a flower classification dataset, available at: https://pan.baidu.com/s/1pBh6tqnp7qtdd1WfjViy-Q (extraction code: dj8v)
After unzipping, you will get the following directory structure:
-flower_photos
-----daisy
-----dandelion
-----roses
-----sunflowers
-----tulips
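The train/val split is the same as in the AlexNet article. If you still need to create the flower_data/train and flower_data/val folders, the sketch below shows one way to do it; the 10% validation ratio, the random seed, and the '../datas/flower_data' output path are assumptions, adjust them to match your own setup.
import os
import random
import shutil

src = 'flower_photos'              # unzipped dataset
dst = '../datas/flower_data'       # output root used by train.py below
val_ratio = 0.1                    # assumed split ratio
random.seed(0)

for cls in os.listdir(src):
    cls_dir = os.path.join(src, cls)
    if not os.path.isdir(cls_dir):
        continue                   # skip loose files such as the license
    images = os.listdir(cls_dir)
    random.shuffle(images)
    n_val = int(len(images) * val_ratio)
    for i, name in enumerate(images):
        split = 'val' if i < n_val else 'train'
        out_dir = os.path.join(dst, split, cls)
        os.makedirs(out_dir, exist_ok=True)
        shutil.copy(os.path.join(cls_dir, name), os.path.join(out_dir, name))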
3. Network Introduction
VGG paper: https://arxiv.org/abs/1409.1556
Although AlexNet showed that deep convolutional neural networks can achieve excellent results, it did not provide simple rules to guide later researchers in designing new networks.
VGG is named after the authors' lab, the Visual Geometry Group. VGG introduced the idea of building deep models by repeatedly stacking simple basic blocks.
AlexNet has only 5 convolutional layers, while VGG has several times as many (VGG builds its network from 5 reusable convolutional blocks), so it can extract more abstract and richer features. Unlike AlexNet, VGG uses 3×3 convolution kernels: two stacked 3×3 convolutions can replace one 5×5 convolution, covering the same receptive field with fewer parameters.
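To make the parameter comparison concrete, the short standalone snippet below (not part of the tutorial files) counts the parameters of two stacked 3×3 convolutions versus a single 5×5 convolution at an example width of 64 channels; the stacked version has roughly 28% fewer parameters.
import torch.nn as nn

channels = 64  # example channel width; the ratio is similar at other widths

# two stacked 3x3 convolutions cover the same receptive field as one 5x5
stacked_3x3 = nn.Sequential(
    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
)
single_5x5 = nn.Conv2d(channels, channels, kernel_size=5, padding=2)

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(stacked_3x3))  # 73,856  = 2 * (3*3*64*64 + 64)
print(count_params(single_5x5))   # 102,464 = 5*5*64*64 + 64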
Unlike the AlexNet code in the first article, the VGG13 training here uses explicit weight initialization, which makes gradient descent more effective and helps the network converge better.
4. Building and Testing the Network
model.py
import torch
import torch.nn as nn
class VGG13(nn.Module):
def __init__(self, num_classes=5, init_weights=False):
super(VGG13, self).__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1), # 224x224
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1), # 224x224
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 112x112
nn.Conv2d(64, 128, kernel_size=3, padding=1), # 112x112
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1), # 112x112
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 56x56
nn.Conv2d(128, 256, kernel_size=3, padding=1), # 56x56
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1), # 56x56
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 28x28
nn.Conv2d(256, 512, kernel_size=3, padding=1), # 28x28
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, padding=1), # 28x28
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 14x14
nn.Conv2d(512, 512, kernel_size=3, padding=1), # 14x14
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, padding=1), # 14x14
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 7x7
)
self.head = nn.Sequential(
nn.Linear(512*7*7, 2048),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(2048, 2048),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(2048, num_classes)
)
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.backbone(x)
x = torch.flatten(x, start_dim=1)
x = self.head(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
if __name__ == '__main__':
x = torch.randn((1,3,224,224))
model = VGG13(init_weights=True)
y = model(x)
print(y)
Unlike the original paper, the fully connected layers here use 2048 units instead of 4096, which reduces the parameter count: the saved model shrinks from roughly 510 MB to 247 MB, and training uses about 400 MB less GPU memory.
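If you want to check the size yourself, a rough sketch like the following counts the parameters and estimates the checkpoint size, assuming 4 bytes per float32 parameter; with the 2048-unit head this comes out to roughly 65 million parameters, about 248 MB, consistent with the 247 MB file mentioned above.
from model import VGG13

model = VGG13(num_classes=5)
n_params = sum(p.numel() for p in model.parameters())
print(f'parameters: {n_params:,}')                         # roughly 65 million
print(f'approx. size: {n_params * 4 / 1024 ** 2:.1f} MB')  # roughly 248 MB in float32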
train.py
from model import VGG13
from torchvision import transforms, datasets, utils
from tqdm import tqdm
import torch.nn as nn
import torch
import argparse
import sys
def run(opt):
    # load the data and build the transforms
    data_transform = {
'train': transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        'val': transforms.Compose([transforms.Resize((224, 224)),  # deterministic resize for validation
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}
    train_dataset = datasets.ImageFolder(root=opt.data + '/train', transform=data_transform['train'])
    val_dataset = datasets.ImageFolder(root=opt.data + '/val', transform=data_transform['val'])
train_data_num = len(train_dataset)
val_data_num = len(val_dataset)
train_data_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=8)
val_data_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=opt.batch_size,
shuffle=False,
num_workers=8)
    print(f'==> train data number: {train_data_num}, val data number: {val_data_num}')
    # build the network
device = opt.device
model = VGG13(num_classes=5, init_weights=True)
model.to(device)
print(f'==> use {device}')
save_path = './vgg13_best.pt'
loss_function = nn.CrossEntropyLoss()
running_loss = 0.0
best_acc = 0.0
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
for epoch in range(opt.epochs):
model.train()
train_bar = tqdm(train_data_loader, file=sys.stdout)
for step, (images, labels) in enumerate(train_bar):
optimizer.zero_grad()
outputs = model(images.to(device))
loss = loss_function(outputs, labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()
train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, opt.epochs, loss)
        # validate the model
model.eval()
acc = 0.0
with torch.no_grad():
val_bar = tqdm(val_data_loader, file=sys.stdout)
for val_data in val_bar:
val_imgs, val_labels = val_data
outputs = model(val_imgs.to(device))
predict_y = torch.max(outputs, dim=1)[1]
acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
val_accurate = acc / val_data_num
print('[epoch %d] val_accuracy: %.3f' % (epoch + 1, val_accurate))
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(model.state_dict(), save_path)
print('==> Finished Training!')
def get_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='../datas/flower_data', help='train data source')
parser.add_argument('--device', type=str, default='cuda:0', help='training device')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=50, help='total training epochs')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
opt = get_opt()
run(opt)
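With these defaults, training can be launched from the terminal, for example (point --data at your own flower_data directory):
python train.py --data ../datas/flower_data --batch-size 16 --epochs 50 --device cuda:0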
predict.py
import torch
from torchvision import transforms
import argparse
from PIL import Image
import matplotlib.pyplot as plt
from model import VGG13
def run(opt):
device = opt.device if torch.cuda.is_available() else 'cpu'
data_transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img0 = Image.open(opt.img)
img = data_transform(img0)
    img = torch.unsqueeze(img, dim=0).to(device)  # add a batch dimension
    # load the model
    model = VGG13(num_classes=5).to(device)
    model.load_state_dict(torch.load(opt.weights, map_location=device))
    model.eval()  # evaluation mode: forward pass only, no gradient updates
with torch.no_grad():
output = torch.squeeze(model(img)).cpu()
output = torch.softmax(output, dim=0)
conf = torch.max(output).numpy()
cls = torch.argmax(output).numpy()
classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print_res = 'class: {} confidence: {}'.format(classes[cls], conf)
print(print_res)
plt.imshow(img0)
plt.title(print_res)
plt.show()
def get_opt():
parser = argparse.ArgumentParser()
    parser.add_argument('--img', type=str, default=r'D:\Codes\DLNet\tulip1.jpeg', help='image to classify')
    parser.add_argument('--device', type=str, default='cuda:0', help='inference device')
    parser.add_argument('--weights', type=str, default='vgg13_best.pt', help='model weights path')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
opt = get_opt()
run(opt)
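Prediction on a single image then looks like this (the image path is only an example):
python predict.py --img ./tulip1.jpeg --weights vgg13_best.pt --device cuda:0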
evaluation.py
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import argparse
import matplotlib.pyplot as plt
import numpy as np
from model import VGG13
def get_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default=r'D:\Codes\DLNet\datas\flower_data\val', help='test data source')
    parser.add_argument('--weights', type=str, default='vgg13_best.pt')
parser.add_argument('--device', type=str, default='cuda:0')
opt = parser.parse_args()
return opt
def evaluate_model(model, dataloader, device):
model.eval()
y_true = []
y_pred = []
print('start infer...')
with torch.no_grad():
for img, labels in dataloader:
img = img.to(device)
labels = labels.to(device)
outputs = model(img)
_, predicted = torch.max(outputs, 1)
y_true.extend(labels.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)
cm = confusion_matrix(y_true, y_pred)
print('infer end, start plot...')
return precision, recall, f1, cm
def generate_infos(precision, recall, f1, cm):
    # dataset class names
classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    # save the per-class results to a txt file
with open('evaluation_results.txt', 'w') as f:
f.write("Class\t\tPrecision\t\tRecall\t\tF1 Score\n")
for i in range(len(classes)):
if len(classes[i]) > 6:
f.write("{}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\n".format(classes[i], precision[i], recall[i], f1[i]))
else:
f.write("{}\t\t{:.4f}\t\t{:.4f}\t\t{:.4f}\n".format(classes[i], precision[i], recall[i], f1[i]))
    # plot the confusion matrix
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()  # add a color bar mapping matrix values to colors
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
    # write the row-normalized values inside the matrix cells
thresh = cm.max() / 2.
for i in range(len(classes)):
for j in range(len(classes)):
r = cm[i, j] / cm[i, :].sum()
plt.text(j, i, '{:.2f}'.format(r),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.savefig('confusion_matrix.jpg')
def main():
opt = get_opt()
    # data preprocessing: use a deterministic resize for evaluation
    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # build the data loader
test_dataset = ImageFolder(opt.data, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)
    # load the model
    model = VGG13().to(opt.device)
    model.load_state_dict(torch.load(opt.weights, map_location=opt.device))
    # run the evaluation
precision, recall, f1, cm = evaluate_model(model, test_dataloader, opt.device)
generate_infos(precision, recall, f1, cm)
if __name__ == '__main__':
main()
5. Model Evaluation Results
When running evaluation.py, specify the evaluation data path, the model weights to evaluate, and the device, for example:
python evaluation.py --data /xxx/xxx --weights vgg13_best.pt --device cuda:0
After running it, you will see results like the following: