第一步:准备数据
10种食物数据:["冰激凌", "鸡蛋布丁", "烤冷面", "芒果班戟", "三明治", "松鼠鱼", "甜甜圈", "土豆泥", "小米粥", "玉米饼"],训练数据集总共有4050张图片,测试数据集总共有500张图片,验证数据集总共有450张图片

第二步:搭建模型
本文选择resnet34,efficientnet_b0,convnext_tiny,其网络结构分别如下:



由于是十分类问题,直接套用网络肯定是不行,因此会在全连接部分做手脚,参考代码如下:
# model = create_model(num_classes=args.num_classes).to(device)
# model = models.resnet34(pretrained=False)
# num_ftrs = model.fc.in_features # 获取全连接层的输入特征数
# model.fc = nn.Linear(num_ftrs, args.num_classes) # 修改全连接层以适应10个类别的输出
# model = model.to(device)
model = models.efficientnet_b0(pretrained=False)
num_ftrs = model.classifier[1].in_features # 获取原始全连接层的输入特征数
model.classifier = nn.Sequential(
nn.Linear(num_ftrs, args.num_classes) # 修改最后一层为10个输出类别的全连接层
)
model = model.to(device)
第三步:训练代码
1)损失函数为:交叉熵损失函数
2)网络可以从头训练或者利用预训练模型进行训练:
for epoch in range(args.epochs):
# 计时器time_start
time_start = time.time()
# train
train_loss, train_acc = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch,
lr_scheduler=lr_scheduler)
# validate
val_loss, val_acc = evaluate(model=model,
data_loader=val_loader,
device=device,
epoch=epoch)
time_end = time.time()
f.write("[epoch {}] train_loss: {:.3f},train_acc:{:.3f},val_loss:{:.3f},val_acc:{:.3f},Spend_time:{:.3f}S"
.format(epoch + 1, train_loss, train_acc, val_loss, val_acc, time_end - time_start))
f.flush()
# add Training results into tensorboard
# #######################################
tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
tb_writer.add_scalar(tags[0], train_loss, epoch)
tb_writer.add_scalar(tags[1], train_acc, epoch)
tb_writer.add_scalar(tags[2], val_loss, epoch)
tb_writer.add_scalar(tags[3], val_acc, epoch)
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
# add figure into tensorboard
# #######################################
fig = plot_class_preds(net=model,
images_dir=r"plot_img",
transform=data_transform["val"],
num_plot=6,
device=device)
if fig is not None:
tb_writer.add_figure("predictions vs. actuals",
figure=fig,
global_step=epoch)
if val_acc > best_acc:
best_acc = val_acc
f.write(',save best model')
torch.save(model.state_dict(), save_path + "/weights/bestmodel.pth")
f.write('\n')
f.close()
第四步:统计正确率

第五步:界面代码运行


第六步:整个工程的内容

项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷
https://www.bilibili.com/video/BV1PbkDBREW3/

554

被折叠的 条评论
为什么被折叠?



