一、方案背景
蘑菇可分为食用蘑菇和有毒蘑菇。可食用蘑菇营养丰富、味道鲜美,受到大家的喜爱。然而,部分有毒蘑菇和食用蘑菇宏观特征极其相似,普通人难以正确分辨。有毒蘑菇的危害很大:根据云南省疾控中心统计数据,2021年云南省食用野生菌中毒事件达600余起,导致2000多人中毒,20多人死亡。对于有毒蘑菇的鉴别,目前除了医院、检测机构和专家学者,没有简便高效的手段和途径。在此背景下,本方案利用开源的毒蘑菇图片数据集(https://www.kaggle.com/datasets/stepandupliak/predict-poison-mushroom-by-photo),提出一个基于ResNet的毒蘑菇识别模型,可以较为精确快捷地判别蘑菇类型,从而为人们日常分辨有毒蘑菇提供有效指引。
二、方案简介
本方案通过开源的毒蘑菇数据集,基于神经网络模型ResNet进行图像分类的微调,训练过程中采用图像增强的方法来提升模型的鲁棒性,同时利用Intel® Extension for PyTorch*工具包加速模型的推理,从而得到一个分类精度较高和推理速度快的毒蘑菇识别模型,具有低成本和方便快捷的特点,可为人们日常生活和野外鉴别有毒蘑菇提供精确及时的指引,并且有利于减少人们误食有毒蘑菇等事故的发生。
三、建模流程图
四、硬件配置
五、推理结果对比分析
注:1、平均推理耗时 = 模型对测试数据重复10次推理的总耗费时间的平均值 / 1052
2、IPEX指的是intel_extension_for_pytorch
六、部分推理对比分析的关键代码
import intel_extension_for_pytorch as ipex
def inference_with_ipex(test_loader, model_path):
model = MushroomNet()
model.load_state_dict(torch.load(model_path))
model.eval()
model = ipex.optimize(model)
val_acc = 0.0
test_pred = []
with torch.no_grad():
start_time = time.time()
for i, (input, target) in enumerate(test_loader):
input = input.cpu()
target = target.cpu()
# compute output
output = model(input)
test_pred.append(output.data.cpu().numpy())
end_time = time.time()
print('累计用时:{}'.format(end_time-start_time))
return np.vstack(test_pred)
def inference_without_ipex(test_loader, model_path):
model = MushroomNet()
model.load_state_dict(torch.load(model_path))
model.eval()
val_acc = 0.0
test_pred = []
with torch.no_grad():
start_time = time.time()
for i, (input, target) in enumerate(test_loader):
input = input.cpu()
target = target.cpu()
# compute output
output = model(input)
test_pred.append(output.data.cpu().numpy())
end_time = time.time()
print('累计用时:{}'.format(end_time-start_time))
return np.vstack(test_pred)
上图中test_loader是测试加载管道(batch_size设为32),model_path是最佳模型读取路径,inference_with_ipex是具有ipex加速推理的集成函数,inference_without_ipex是没有ipex加速推理的集成函数。
七、总结