英特尔oneAPI人工智能黑客松第二季-基于ResNet的毒蘑菇识别

一、方案背景

       蘑菇可分为食用蘑菇和有毒蘑菇。可食用蘑菇营养丰富、味道鲜美,受到大家的喜爱。然而,部分有毒蘑菇和食用蘑菇宏观特征极其相似,普通人难以正确分辨。有毒蘑菇的危害很大:根据云南省疾控中心统计数据,2021年云南省食用野生菌中毒事件达600余起,导致2000多人中毒,20多人死亡。对于有毒蘑菇的鉴别,目前除了医院、检测机构和专家学者,没有简便高效的手段和途径。在此背景下,本方案利用开源的毒蘑菇图片数据集(https://www.kaggle.com/datasets/stepandupliak/predict-poison-mushroom-by-photo),提出一个基于ResNet的毒蘑菇识别模型,可以较为精确快捷地判别蘑菇类型,从而为人们日常分辨有毒蘑菇提供有效指引。

二、方案简介

       本方案通过开源的毒蘑菇数据集,基于神经网络模型ResNet进行图像分类的微调,训练过程中采用图像增强的方法来提升模型的鲁棒性,同时利用Intel® Extension for PyTorch*工具包加速模型的推理,从而得到一个分类精度较高和推理速度快的毒蘑菇识别模型,具有低成本和方便快捷的特点,可为人们日常生活和野外鉴别有毒蘑菇提供精确及时的指引,并且有利于减少人们误食有毒蘑菇等事故的发生。

三、建模流程图

 四、硬件配置

 五、推理结果对比分析

注:1、平均推理耗时 = 模型对测试数据重复10次推理的总耗费时间的平均值 / 1052        

       2、IPEX指的是intel_extension_for_pytorch

六、部分推理对比分析的关键代码

import intel_extension_for_pytorch as ipex

def inference_with_ipex(test_loader, model_path):
    model = MushroomNet()
    model.load_state_dict(torch.load(model_path))
    model.eval()
    model = ipex.optimize(model)
    
    val_acc = 0.0
    
    test_pred = []
    with torch.no_grad():
        start_time = time.time()
        for i, (input, target) in enumerate(test_loader):
            input = input.cpu()
            target = target.cpu()

            # compute output
            output = model(input)
            test_pred.append(output.data.cpu().numpy())
        end_time = time.time()
        print('累计用时:{}'.format(end_time-start_time))
    return np.vstack(test_pred)

def inference_without_ipex(test_loader, model_path):
    model = MushroomNet()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    val_acc = 0.0
    
    test_pred = []
    with torch.no_grad():
        start_time = time.time()
        for i, (input, target) in enumerate(test_loader):
            input = input.cpu()
            target = target.cpu()

            # compute output
            output = model(input)
            test_pred.append(output.data.cpu().numpy())
        end_time = time.time()
        print('累计用时:{}'.format(end_time-start_time))
    return np.vstack(test_pred)

上图中test_loader是测试加载管道(batch_size设为32),model_path是最佳模型读取路径,inference_with_ipex是具有ipex加速推理的集成函数,inference_without_ipex是没有ipex加速推理的集成函数。

七、总结

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值