1、数据集
春天来了,我在公园的小道漫步,看着公园遍野的花朵,看起来真让人心旷神怡,一周工作带来的疲惫感顿时一扫而光。难得一个糙汉子有闲情逸致俯身欣赏这些花朵儿,然而令人尴尬的是,我一朵都也不认识。那时我在想,如果有一个花卉识别软件,可以用手机拍一下就知道这是一种什么花朵儿,那就再好不过了。我不知道市场上是否有这样一种软件,但是作为一个从事深度学习的工程师,我马上知道了怎么做,最关键的不是怎么做,而是数据采集。住所附近就是大型公园,一年司机繁花似锦,得益于此,我可以在闲暇时间里采集到大量的花卉数据。本数据集由本人亲自使用手机进行拍摄采集,原始数据集包含了27万张图片,图片的尺寸为1024x1024,为了方便储存和传输,把原图缩小为224x224。采集数据是一个漫长的过程,因此数据集的发布采用分批发布的形式,也就是每采集够16种花卉,就发布一次数据集。每种花卉的图片数量约为2000张,每次发布的数据集的图片数量约为32000张,每次发布的数据集包含的花卉种类都不一样。目前花卉的种类只有48种,分为三批发布,不过随着时间的推移,采集到的花卉越来越多。这里就把数据集分享出来,供各位人工智能算法研究者使用。以下是花卉数据集的简要介绍和下载地址。
(1)花卉数据集01(数据集+训练代码下载地址)
花卉数据集01,采集自2022年,一共16种花卉,数据集大小为32000张,图片大小为224x224的彩色图像。数据集包含的花卉名称为:一年蓬,三叶草,三角梅,两色金鸡菊,全叶马兰,全缘金光菊,剑叶金鸡菊,婆婆纳,油菜花,滨菊,石龙芮,绣球小冠花,蒲公英,蓝蓟,诸葛菜,鬼针草。数据集的缩略图如下:
(2)花卉数据集02(数据集+训练源码下载地址)
花卉数据集02,采集与2023年,一共16种花卉,每种花卉约2000张,总共32000,图片大小为224x224。数据集包含的花卉有:千屈菜,射干,旋覆花,曼陀罗,桔梗,棣棠,狗尾草,狼尾草,石竹,秋英,粉黛乱子草,红花酢浆草,芒草,蒲苇,马鞭草,黄金菊。数据集缩略图如下:
(3)花卉数据集3(数据集+训练源码下载地址)
花卉数据集03,采集与2023年,一共16种花卉,每种花卉约2000张,总共32000,图片大小为224x224。数据集包含的花卉有:北香花介,大花耧斗菜,小果蔷薇,小苜蓿,小蜡,泽珍珠菜,玫瑰,粉花绣线菊,线叶蓟,美丽月见草,美丽芍药,草甸鼠尾草,蓝花鼠尾草,蛇莓,长柔毛野豌豆,高羊茅。数据集缩略图如下:
2、图片分类模型
为了研究不同图片分类模型对于花朵的分类效果,以及图片分类模型在不同硬件平台的推理速度,这里分别使用目前主流的22种图片分类模型进行训练,并在cpu平台和GPU平台进行部署测试。这些模型是如下:
- resnet系列:resnet18、resnet34、resnet50、resnet101、resnet152。
- vgg系列:vgg11、vgg13、vgg6、vgg19。
- squeezenet系列:squeezenet_v1、squeezenet_v2、squeezenet_v3。
- mobilenet系列:mobilenet_v1、mobilenet_v2。
- inception系列:inception_v1、inception_v2、inception_v3。
- 其他系列:alexnet、lenet、mnist、tsl16、zfnet。
以上模型的训练代码基于pytorch架构,内置集成22了种模型,可进行傻瓜式训练。以下的代码块为训练代码的主脚本,完整的训练代码以及数据集请在此链接下载:源码下载链接。
import torch
import torch.nn as nn
import torch.optim as optim
from utils.dataloader import CustomImageDataset
from torch.utils.data import DataLoader
from utils.build_model import build_model
import argparse
import time
import os
if __name__ == '__main__':
parser =argparse.ArgumentParser(description='图片分类模型训练')
parser.add_argument('-input_shape', type=tuple,default=(3,224,224),help='模型输入的通道数、高度、宽度')
parser.add_argument('-train_imgs_dir', type=str,default='dataset/train',help='训练集目录')
parser.add_argument('-test_imgs_dir', type=str,default='dataset/test',help='测试集目录')
parser.add_argument('-classes_file', type=str,default='dataset/classes.txt',help='类别文件')
parser.add_argument('-epochs', type=int,default=50,help='迭代次数')
parser.add_argument('-batch_size', type=int,default=64,help='批大小,根据显存大小调整')
parser.add_argument('-init_weights', type=str,default="init_weights/squeezenet_v1.pth",help='用于初始化的权重,请确保初始化的权重和训练的模型相匹配')
parser.add_argument('-optim', type=str,default="adam",help='优化器选择,可选sgd或者adam. sgd优化器训练效果较好,但参数比较难调节,不好收敛')
parser.add_argument('-lr', type=float,default=0.0001,
help='初始学习率,此参数对模型训练影响较大,如果选择不合适,模型甚至不收敛.\
如果遇到模型训练不收敛(损失函数不下降,准确度很低),可以尝试调整学习率.\
resnet系列推荐优化器选择sgd,学习率设0.001;vgg系列优化器推荐adam,学习率为0.0001,其他模型优化器选择adam,推荐学习率为0.0002')
parser.add_argument('-model_name', type=str,default='squeezenet_v1',
help='模型名称,可选resnet18/resnet34/resnet50/resnet101/resnet152\
/alexnet/lenet/zfnet/tsl16/mnist\
vgg11/vgg13/vgg16,vgg19\
squeezenet_v1/squeezenet_v2/squeezenet_v3\
inception_v1/inception_v2/inception_v3\
mobilenet_v1/mobilenet_v2/\
')
parser.add_argument('-argument', type=bool,default=True,help='是否在训练时开启数据增强模式')
args = parser.parse_args()
print("模型:",args.model_name)
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
classes=[]
try:
with open(args.classes_file,"rt",encoding="ANSI")as f:
for line in f:
classes.append(line.strip())
except:
with open(args.classes_file,"rt",encoding="UTF-8")as f:
for line in f:
classes.append(line.strip())
num_class=len(classes)
model=build_model(args.model_name,args.input_shape,num_class)
if os.path.exists(args.init_weights):
try:
model.load_state_dict(args.init_weights)
except:
model.weights_init()
print("参数初始化失败!请确保初始化参数与模型相一致.")
else:
model.weights_init()
print("没有找到名称为%s的权重文件,模型将跳过参数初始化"%(args.init_weights))
model=model.to(device)
# Create data loaders.
training_data=CustomImageDataset(args.train_imgs_dir,classes,args.argument)
test_data=CustomImageDataset(args.test_imgs_dir,classes,False)
train_dataloader = DataLoader(training_data, batch_size=args.batch_size,shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=args.batch_size)
loss_fn=nn.CrossEntropyLoss()
if args.optim=="adam":
optimizer=optim.Adam(model.parameters(), lr=args.lr)
else:
optimizer=optim.SGD(model.parameters(), lr=args.lr,momentum=0.9,weight_decay=0.0005)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,args.epochs)
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
start=time.time()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
optimizer.zero_grad()
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
loss.backward()
optimizer.step()
scheduler.step()
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
end=time.time()
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}] time: {end-start:>3f}s")
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
start=time.time()
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
end=time.time()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} , Time: {end-start:>3f}s\n")
return correct
best_accuracy=0
for t in range(args.epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
current_accuracy=test(test_dataloader, model, loss_fn)
if current_accuracy>best_accuracy:
best_accuracy=current_accuracy
torch.save(model,"weights/%s_best_accuracy.pth"%(args.model_name))
torch.save(model,"weights/%s_last_accuracy_%.2f.pth"%(args.model_name,current_accuracy))
print("Done!")
3、图片分类模型评估
分别训练了22种模型,图片为224x224的RGB图像。将约96000张图片划分为训练集和测试集,其中测试集占10%,一共9600张,训练集90%,一共86400张。训练充分后,对各种模型的top1、top2、top3、top4、top5准确度进行评估,并分别在cpu平台(intel i9)和gpu平台(RTX 3090)进行推理速度的测试。模型性能评估以及推理速度测试结果如表1所示。
表1 模型性能评估以及推理速度测试结果
模型 | 参数量 [M] | 计算量 [G] | GPU速度[FPS] | CPU速度[FPS] | Top1准确度[%] | Top2准确度[%] | Top3准确度[%] | Top4准确度[%] | Top5准确度[%] |
---|---|---|---|---|---|---|---|---|---|
resnet18 | 11.19 | 3.64 | 557.30 | 204.05 | 98.84 | 99.70 | 99.86 | 99.93 | 99.96 |
resnet34 | 21.30 | 7.35 | 347.39 | 104.35 | 98.64 | 99.74 | 99.86 | 99.93 | 99.96 |
resnet50 | 34.94 | 10.31 | 295.24 | 68.19 | 98.67 | 99.66 | 99.86 | 99.94 | 99.97 |
resnet101 | 53.90 | 17.77 | 171.84 | 41.27 | 98.61 | 99.64 | 99.83 | 99.90 | 99.94 |
resnet152 | 68.41 | 23.46 | 130.21 | 31.20 | 98.44 | 99.58 | 99.86 | 99.92 | 99.96 |
vgg11 | 128.96 | 15.25 | 462.24 | 30.00 | 92.88 | 97.35 | 98.69 | 99.26 | 99.48 |
vgg13 | 129.15 | 22.67 | 411.55 | 22.27 | 95.18 | 98.22 | 99.21 | 99.55 | 99.73 |
vgg16 | 134.46 | 30.99 | 340.34 | 20.14 | 95.35 | 98.48 | 99.29 | 99.50 | 99.60 |
vgg19 | 139.77 | 39.33 | 292.51 | 16.71 | 94.89 | 98.25 | 99.01 | 99.39 | 99.62 |
mobilenet_v1 | 3.25 | 1.16 | 942.09 | 506.42 | 97.45 | 99.44 | 99.71 | 99.83 | 99.90 |
mobilenet_v2 | 4.03 | 0.91 | 489.69 | 386.20 | 95.99 | 98.98 | 99.52 | 99.75 | 99.81 |
inception_v1 | 6.02 | 3.20 | 343.56 | 203.43 | 95.80 | 98.79 | 99.44 | 99.74 | 99.83 |
inception_v2 | 7.85 | 3.34 | 291.10 | 165.49 | 98.30 | 99.54 | 99.80 | 99.85 | 99.90 |
inception_v3 | 21.87 | 7.65 | 136.25 | 71.89 | 99.05 | 99.81 | 99.92 | 99.95 | 99.97 |
squeezenet_v1 | 0.76 | 1.61 | 758.27 | 362.22 | 97.44 | 99.36 | 99.69 | 99.81 | 99.86 |
squeezenet_v2 | 0.76 | 1.61 | 704.60 | 360.75 | 97.27 | 99.23 | 99.67 | 99.80 | 99.85 |
squeezenet_v3 | 1.10 | 2.37 | 658.07 | 267.28 | 98.28 | 99.54 | 99.77 | 99.85 | 99.89 |
mnist_net | 214.45 | 51.37 | 189.81 | 10.13 | 89.47 | 96.20 | 98.14 | 98.91 | 99.28 |
AlexNet | 17.69 | 2.35 | 858.92 | 211.73 | 96.20 | 98.79 | 99.49 | 99.69 | 99.76 |
LeNet | 78.45 | 0.94 | 1041.75 | 76.12 | 84.86 | 93.75 | 96.55 | 97.94 | 98.70 |
TSL16 | 116.95 | 23.56 | 381.63 | 24.09 | 95.61 | 98.40 | 99.15 | 99.53 | 99.69 |
ZF_Net | 72.09 | 2.68 | 351.87 | 82.68 | 96.58 | 99.03 | 99.47 | 99.68 | 99.80 |
从表一展示的结果来看,面对48种花卉的分类任务:如果只关心Top5分类准确度,那么这些模型均能达到98%以上的分类准确度,大部分模型的Top5准确度都能达到99%以上,对于实际应用而言,花卉分类程序通常会给出5个备选项,这样的话,只要5个备选项里边存在一个正确选项,就可以认为花卉分类是成功的。当然,如果追求单一选项的准确性,resnet系列模型、inception系列、squeezenet系列模型,在Top1分类准确度上表现不俗,可以达到97%以上的准确度。通常来说,可以工程化的图片分类模型,不仅仅要求其具备良好的分类准确度,还对其推理速度有一定的要求。表1的推理速度测试数据,分别在Intel i9 CPU平台和英伟达RTX 3090 GPU平台进行测试,推理用的软件接口是onnx推理架构,测试的策略是逐一对1000张224x224的彩色图片输送到模型中进行推理,统计其总的推理时间,然后计算平均推理帧率。从表1的数据可以得知,GPU平台的推理速度要比CPU平台的推理的速度快很多,而且在GPU平台推理帧率高的模型,在CPU平台的推理帧率未必高,反之亦然,也就是说,模型推理帧率的排名,跟硬件平台是有关的。在GPU平台上推理帧率比较靠前的模型是squeezenet系列、mobilenet系列,以及alexnet和lenet;在CPU平台上推理帧率比较靠前的模型是squeezenet系列、mobilenet系列、alexnet、inception_v1、resnet18。综合来说,squeezenet系列、mobilenet系列得分是最高的,因为他们在分类准确度上表现优秀,并且在GPU和CPU平台上的推理帧率都变现不错,而且模型的参数量很小,适合在线部署和嵌入式部署,所以这些模型应当优先选择。另外,从评估和测试的结果来看,还可以得到以下几个结论:
- 对于可以并行计算的硬件平台来说,比如GPU、NPU以及一些具有批处理能力的CPU,模型的推理帧率跟模型的参数量和计算量没有绝对的关联性,更多的是跟模型的结构有关,如果模型适合于并行计算,那么即使模型具有较大的计算量,其推理速度也可以很快;
- 对于串行执行的计算硬件来说,比如常规指令的CPU,模型的推理速度跟模型的计算量的是线性相关的,也就是说计算量越大,推理帧率越低;
- 适合GPU平台部署的模型未必适合在CPU平台上部署,所以模型的选择要根据最终的部署平台而定。
4、总结
花卉数据集共包括96000张图片,囊括了48种花卉的类别,其中10%为测试集,90%为训练集。图片的大小为224x224,通道数为3。一共使用了22种模型进行训练,通过模型评估和硬件平台部署测试得出结论:squeezenet_v1、squeezenet_v2、squeezenet_v3、mobilenet_v1、mobilenet_v2几个模型,具有参数量小,计算量小,分类准确度高的优点,并且在GPU平台和CPU平台上推理速度较快,适合在各种平台上部署,特别是适合移动端和嵌入式的部署。