项目设计
- 设计爬虫程序获取中草药图片构建数据集,采用多种数据增广和图像增强方法提升训练集质量;
- 通过实验比较并调整参数,选定ResNeSt50模型;采用CosineAnnealingLR学习率调整策略减少模型参数陷入局部最优解的概率,调整损失函数为Focal Loss + Contrastive Loss,修改激活函数为LeakyReLU及其他参数提升模型训练效果;
- 设计网站,用户可通过Web前端上传中草药图像,后端利用深度学习模型进行识别,并实时保存数据库数据。
项目展示
实现功能
用户可通过 Web 前端上传中草药图像;前端将图像发送至后端,后端通过深度学习模型进行识别;随后,后端将预测结果(包括中草药种类及相关信息)返回给前端,供用户查看。整个过程中,数据库数据将被实时调用和保存。系统现在可以识别18种中草药植株,在测试集上的准确率达到97.8%。
功能展示
打开网站,点击选择需要识别的中草药图片。
点击提交,返回识别结果。
后端运行过程:
源代码展示
图像增广函数
def process_images(folder_path):
    """Augment every image in the first-level subfolders of *folder_path*.

    For each image, seven augmented variants are written next to the
    original, each reusing the source file name with a prefix:
    ``mirror_`` (horizontal flip), ``scale_`` (half size), ``shift_``
    ((50, 50) wrap-around translation), ``rotate_left_`` / ``rotate_right_``
    (±90°), ``crop_`` (fixed 200x200 crop), ``enhanced_`` (brightness and
    contrast both x1.5).
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    # Hoisted loop invariant: depth of the root folder itself.
    base_depth = len(folder_path.split(os.sep))
    for root, dirs, files in os.walk(folder_path):
        # Only process the immediate subfolders (one level below the root).
        if len(root.split(os.sep)) - base_depth != 1:
            continue
        for file in files:
            # Skip anything that is not an image by extension.
            if not file.lower().endswith(image_exts):
                continue
            image_path = os.path.join(root, file)
            # FIX: context manager guarantees the file handle is released
            # even if one of the transforms or saves raises.
            with Image.open(image_path) as image:
                # Mirror
                image.transpose(Image.FLIP_LEFT_RIGHT).save(
                    os.path.join(root, 'mirror_' + file))
                # Scale to half size
                image.resize((image.width // 2, image.height // 2)).save(
                    os.path.join(root, 'scale_' + file))
                # Translate by (50, 50) with wrap-around
                ImageChops.offset(image, 50, 50).save(
                    os.path.join(root, 'shift_' + file))
                # Rotate left / right by 90 degrees
                image.rotate(90).save(
                    os.path.join(root, 'rotate_left_' + file))
                image.rotate(-90).save(
                    os.path.join(root, 'rotate_right_' + file))
                # Fixed crop; NOTE(review): assumes images are at least
                # 300x300 — smaller inputs yield black padding. Confirm the
                # dataset's minimum image size.
                image.crop((100, 100, 300, 300)).save(
                    os.path.join(root, 'crop_' + file))
                # Brightness then contrast boost (both x1.5)
                enhanced = ImageEnhance.Brightness(image).enhance(1.5)
                enhanced = ImageEnhance.Contrast(enhanced).enhance(1.5)
                enhanced.save(os.path.join(root, 'enhanced_' + file))
训练代码
import time
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# %matplotlib inline  -- Jupyter-only magic; commented out so this file is valid Python.
import torch.optim as optim
from resnest.torch import resnest50
import warnings

warnings.filterwarnings("ignore")

# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from torchvision import transforms

# Training-set preprocessing: random resized crop + horizontal flip for
# augmentation, then ToTensor and ImageNet mean/std normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Test-set preprocessing (Resize-CenterCrop-ToTensor-Normalize):
# deterministic, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Dataset layout: dataset_split/{train,val}/<class_name>/<image files>
dataset_dir = 'dataset_split'
train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets

train_dataset = datasets.ImageFolder(train_path, train_transform)
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

# FIX: idx_to_labels and n_class were referenced but never defined in the
# original script (NameError); derive both from the dataset's class mapping.
n_class = len(train_dataset.classes)
idx_to_labels = {idx: name for name, idx in train_dataset.class_to_idx.items()}

# Persist both label mappings for the inference service.
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)

from torch.utils.data import DataLoader

BATCH_SIZE = 128
# Training loader reshuffles each epoch; evaluation loader keeps file order.
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=16)
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=16)

from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler

# ResNeSt-50 trained from scratch (no pretrained weights).
model = resnest50(pretrained=False)
# Replace the classification head.
# NOTE(review): LeakyReLU *after* the final linear layer dampens negative
# logits before the loss — confirm this is the intended design.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, n_class),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
)
optimizer = optim.Adam(model.parameters())
model = model.to(device)
# Combined Focal + Contrastive loss used as the training criterion.
class FocalContrastiveLoss(nn.Module):
    """Focal loss plus a cosine-similarity contrastive penalty.

    The focal term down-weights easy examples by ``(1 - p_t) ** gamma``;
    the contrastive term pushes the logit vector toward the margin-scaled
    one-hot encoding of the target class.
    """

    def __init__(self, alpha=1, gamma=2, margin=1):
        super(FocalContrastiveLoss, self).__init__()
        self.alpha = alpha    # overall focal-term weight
        self.gamma = gamma    # focusing exponent
        self.margin = margin  # scale of the one-hot target vector

    def forward(self, inputs, targets):
        # Focal term: per-sample cross entropy reweighted by (1 - p_t)^gamma.
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        focal_term = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        # Contrastive term: 1 - cosine similarity between the logits and the
        # margin-scaled one-hot targets, averaged over the batch (a scalar).
        target_one_hot = torch.zeros_like(inputs)
        target_one_hot.scatter_(1, targets.view(-1, 1), 1)
        cos_sim = F.cosine_similarity(inputs, target_one_hot * self.margin, dim=1)
        contrastive_term = (1 - cos_sim).clamp(min=0).mean()

        # Broadcast-add the scalar contrastive term, then average.
        return (focal_term + contrastive_term).mean()
# Training criterion: the Focal + Contrastive loss defined above.
criterion = FocalContrastiveLoss()

# Number of training epochs.
EPOCHS = 300

# Cosine-annealing learning-rate schedule.
# FIX: the original passed eta_min=0.01, which is *above* Adam's default
# base lr of 1e-3, so the schedule annealed the learning rate upward.
# Use a small floor well below the base lr instead.
# NOTE(review): this rebinding shadows the lr_scheduler module imported
# above; kept as-is because the training loop calls lr_scheduler.step().
lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)

from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
def train_one_batch(images, labels):
    '''
    Run one training batch and return a log dict for it.

    Relies on module-level globals: model, criterion, optimizer, device,
    and the loop counters `epoch` / `batch_idx` maintained by the caller.
    '''
    # Move the batch to the compute device.
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)  # forward pass -> logits
    loss = criterion(outputs, labels)  # scalar mean loss over the batch
    # Backpropagate and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Predicted class for each image = argmax over logits.
    _, preds = torch.max(outputs, 1)
    preds = preds.cpu().numpy()
    loss = loss.detach().cpu().numpy()
    outputs = outputs.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    log_train = {}
    log_train['epoch'] = epoch        # global epoch counter (set by the loop)
    log_train['batch'] = batch_idx    # global batch counter (set by the loop)
    # Batch-level classification metrics.
    log_train['train_loss'] = loss
    log_train['train_accuracy'] = accuracy_score(labels, preds)
    # log_train['train_precision'] = precision_score(labels, preds, average='macro')
    # log_train['train_recall'] = recall_score(labels, preds, average='macro')
    # log_train['train_f1-score'] = f1_score(labels, preds, average='macro')
    return log_train
def evaluate_testset():
    '''
    Evaluate on the whole test set and return a dict of metrics.

    Relies on module-level globals: model, criterion, device, test_loader,
    and the `epoch` counter. The caller is expected to have switched the
    model to eval() mode beforehand.
    '''
    loss_list = []
    labels_list = []
    preds_list = []
    # No gradients needed for evaluation.
    with torch.no_grad():
        for images, labels in test_loader:  # one batch at a time
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)  # forward pass -> logits
            # Predicted class for each image = argmax over logits.
            _, preds = torch.max(outputs, 1)
            preds = preds.cpu().numpy()
            loss = criterion(outputs, labels)  # scalar mean loss for this batch
            loss = loss.detach().cpu().numpy()
            outputs = outputs.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            loss_list.append(loss)
            labels_list.extend(labels)
            preds_list.extend(preds)
    log_test = {}
    log_test['epoch'] = epoch  # global epoch counter (set by the loop)
    # Aggregate classification metrics over the whole test set.
    log_test['test_loss'] = np.mean(loss_list)
    log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)
    log_test['test_precision'] = precision_score(labels_list, preds_list, average='macro')
    log_test['test_recall'] = recall_score(labels_list, preds_list, average='macro')
    log_test['test_f1-score'] = f1_score(labels_list, preds_list, average='macro')
    return log_test
# --- main training loop ----------------------------------------------------
# FIX: the original called wandb.log without ever importing wandb (NameError
# at runtime) and used the private DataFrame._append API. Import wandb
# defensively (logging becomes a no-op when unavailable) and append rows
# with the public pd.concat instead.
try:
    import wandb
except ImportError:
    wandb = None

# FIX: the checkpoint directory must exist before torch.save is called.
os.makedirs('models', exist_ok=True)

epoch = 0
batch_idx = 0
best_test_accuracy = 0

# Seed the train log with one warm-up batch (recorded as epoch 0).
df_train_log = pd.DataFrame()
log_train = {'epoch': 0, 'batch': 0}
images, labels = next(iter(train_loader))
log_train.update(train_one_batch(images, labels))
df_train_log = pd.concat([df_train_log, pd.DataFrame([log_train])],
                         ignore_index=True)

# Seed the test log with an initial evaluation (epoch 0).
df_test_log = pd.DataFrame()
log_test = {'epoch': 0}
log_test.update(evaluate_testset())
df_test_log = pd.concat([df_test_log, pd.DataFrame([log_test])],
                        ignore_index=True)

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')

    ## Training phase
    model.train()
    for images, labels in tqdm(train_loader):
        batch_idx += 1
        log_train = train_one_batch(images, labels)
        df_train_log = pd.concat([df_train_log, pd.DataFrame([log_train])],
                                 ignore_index=True)
        if wandb is not None:
            wandb.log(log_train)
    # NOTE(review): stepping once per epoch (the usual convention for
    # CosineAnnealingLR); the flattened original was ambiguous about whether
    # this sat inside the batch loop — confirm against training curves.
    lr_scheduler.step()

    ## Evaluation phase
    model.eval()
    log_test = evaluate_testset()
    df_test_log = pd.concat([df_test_log, pd.DataFrame([log_test])],
                            ignore_index=True)
    if wandb is not None:
        wandb.log(log_test)

    # Keep only the single best checkpoint, named by its test accuracy.
    if log_test['test_accuracy'] > best_test_accuracy:
        # Remove the previous best checkpoint, if any.
        old_best_checkpoint_path = 'models/best-{:.3f}.pth'.format(best_test_accuracy)
        if os.path.exists(old_best_checkpoint_path):
            os.remove(old_best_checkpoint_path)
        best_test_accuracy = log_test['test_accuracy']
        new_best_checkpoint_path = 'models/best-{:.3f}.pth'.format(log_test['test_accuracy'])
        torch.save(model, new_best_checkpoint_path)
        print('保存新的最佳模型', 'models/best-{:.3f}.pth'.format(best_test_accuracy))

df_train_log.to_csv('训练日志-训练集.csv', index=False)
df_test_log.to_csv('训练日志-测试集.csv', index=False)
后端程序代码
from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from django.http import HttpResponse
from django.shortcuts import render
from django.views.decorators.csrf import csrf_exempt
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
from resnest.torch import resnest50
import os
from PIL import Image
from Herbal_Identification import settings
from app01.models import Herb
# %matplotlib inline
# Use the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Project root directory, taken from Django settings.
BASE_DIR = settings.BASE_DIR
# Path to the best checkpoint exported by the training script.
model_path = os.path.join(BASE_DIR, 'app01', 'models', 'best-0.978.pth')
# Load the full pickled model object once at module import time.
# NOTE(review): torch.load unpickles arbitrary code — only ever load
# checkpoints produced by the trusted training pipeline.
model = torch.load(model_path, map_location=device)
model = model.eval().to(device)
# Inference preprocessing (Resize-CenterCrop-ToTensor-Normalize), matching
# the validation transform used during training.
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                     ])
# Create your views here.
@csrf_exempt
def uploadimg(request):
    """Accept an uploaded herb image, classify it, and return "name$info$path".

    Expects a POST with an 'image' file field. Responds with a plain-text
    string "herb_name$herb_info$static_image_path" (the '$'-separated format
    parsed by the front end), or an error message.
    """
    if request.method == 'POST' and request.FILES.get('image'):
        img_file = request.FILES['image']
        # Persist the upload so it can be reopened with Pillow.
        img_storage_path = "app01/Input/" + img_file.name
        filename = default_storage.save(img_storage_path, ContentFile(img_file.read()))
        # Absolute path of the saved file.
        img_path = default_storage.path(filename)
        try:
            # FIX: convert to RGB — grayscale/RGBA/palette uploads would
            # otherwise yield a tensor whose channel count does not match
            # the 3-channel Normalize transform. The with-block also
            # guarantees the file handle is closed on any error.
            with Image.open(img_path) as img:
                input_img = test_transform(img.convert('RGB'))
            input_img = input_img.unsqueeze(0).to(device)
            # FIX: run inference without building an autograd graph.
            with torch.no_grad():
                pred_logits = model(input_img)
            pred_softmax = F.softmax(pred_logits, dim=1)  # logits -> probabilities
            # Take the single most confident class.
            top_n = torch.topk(pred_softmax, 1)
            pred_ids = top_n[1].cpu().detach().numpy().squeeze()
            idx = int(pred_ids)
            # Herb rows are assumed to be stored so that pk == class index + 1;
            # TODO(review): confirm this invariant against the DB fixtures.
            herb = Herb.objects.get(pk = idx+1)
            jpg_name = herb.name + ".jpg"
            print(jpg_name)
            img_path = os.path.join('static', jpg_name)
            data = herb.name + "$" + herb.info + "$" + img_path
            print(data)
            return HttpResponse(data)
        except FileNotFoundError:
            return HttpResponse('File not found!')
    else:
        return HttpResponse('No image uploaded or invalid request method!')
前端界面代码
<template>
  <div id="app">
    <el-row>
      <!-- Left card: image upload + submit/reset controls -->
      <el-col :span="8" style="margin: 15px;">
        <el-card style="height:600px">
          <el-form>
            <el-form-item>
              <!-- Drag-and-drop upload; actual submission is handled
                   manually in onSubmit (beforeUpload returns false) -->
              <el-upload drag action="http://127.0.0.1:8000/uploadImg/" :on-success="onSuccess"
                :before-upload="beforeUpload" :file-list="[]" accept="image/*"
                style="display: flex; justify-content: center;" multiple>
                <!-- Preview of the selected image, or the drop-zone hint -->
                <el-image v-if="img" :src="imgUrl" class="uploaded-image"></el-image>
                <div v-else>
                  <i class="el-icon-upload"></i>
                  <div class="el-upload__text">将文件拖到此处,或<em>点击上传</em></div>
                  <div class="el-upload__text">只能上传jpg/png文件</div>
                </div>
              </el-upload>
            </el-form-item>
            <el-form-item style="display: flex; justify-content: center; gap: 40px;">
              <el-button type="primary" @click="onSubmit"
                style="padding-right: 40px; padding-left: 40px; margin-right: 20px;">提交</el-button>
              <el-button type="success" @click="reset"
                style="padding-right: 40px;padding-left: 40px; margin-left: 20px;">重置</el-button>
            </el-form-item>
          </el-form>
          <!-- Predicted herb name, shown once a result has arrived -->
          <p v-if="sort">识别结果:{{ sort }}</p>
        </el-card>
      </el-col>
      <!-- Right card: reference picture and description of the herb -->
      <el-col :span="14" style="margin: 15px;">
        <el-card v-if="intro" style="height:600px;">
          <img :src="info_img" style="width:400px; height:300px" alt="Image">
          <p>{{ intro }}</p>
        </el-card>
      </el-col>
    </el-row>
  </div>
</template>
<script>
export default {
name: 'App',
data() {
return {
img: null,
imgUrl: '',
sort: '',
intro: '',
info_img: '',
};
},
methods: {
beforeUpload(file) {
this.img = file;
this.imgUrl = URL.createObjectURL(file);
this.sort = null;
this.intro = null;
return false;
},
onSubmit() {
// 指向this
let _this = this
if (!this.img) {
return;
}
let formData = new FormData();
formData.append('image', this.img);
this.$axios.post('http://127.0.0.1:8000/uploadImg/', formData)
.then(res => {
let resData = res.data
let dataArray = resData.split("$")
console.log(dataArray)
_this.sort = dataArray[0]
_this.intro = dataArray[1]
_this.info_img = "http://127.0.0.1:8000/" + dataArray[2]
})
.catch(err => {
console.log(err);
});
},
onSuccess(response) {
console.log(response);
},
reset() {
this.img = null;
this.imgUrl = null;
this.sort = null;
this.intro = null;
this.info_img = null;
}
},
};
</script>
<style lang="less" scoped>
/* Make the preview fill the upload drop zone; cover keeps the aspect
   ratio and crops the overflow instead of distorting the image. */
.uploaded-image {
  width: 100%;
  height: 100%;
  object-fit: cover;
}
</style>