基于PaddlePaddle的图像分类

基于PaddlePaddle的图像分类


本文使用了百度PaddlePaddle模型库中的模型,数据库和代码,这里只是讲解如何利用该平台上开源项目进行改进。

一、准备

模型链接
点击链接会进入模型界面,如图:
可以选择启动环境直接在线运行,也可以将文档中的代码复制下来进行简单修改后,在自己电脑上运行,当然,还需要下载数据集。如果不想自己训练的话,作者已经把训练好的模型放到了环境中,可以直接下载。
如上图,最左侧的freeze-model是作者训练好的模型,可以直接下载,使用的时候直接导入。data为数据集,里面包含了训练接和测试集。

二、离线运行

将文档中的代码按自己所需进行复制,放到自己的编译器中,我用的是Pycharm,Python3.8。记得先安装好Paddle,安装教程在官网有:安装
一切都准备好,下面就是配置环境和改代码了:
在工程中新建一个文件夹,专门用来做这个项目,这样能看上去更简洁明了。文件夹中再分两个子文件夹,一个用来放数据集,一个用来放模型,代码直接放在主文件夹中,就像这样:

我这里已经把代码修改好了,数据集和模型的路径什么的就现用现改,哪里有问题就再改,反正这个程序中用了很多名称来代指文件的路径,在复制的时候要把前面的字典定义也要放进去,代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import time
import codecs
import paddle.fluid as fluid
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt

target_size = [3, 224, 224]
mean_rgb = [127.5, 127.5, 127.5]
data_dir = "data"
eval_file = "eval.txt"
use_gpu = True
place = fluid.CPUPlace() if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
save_freeze_dir = "freeze-model"
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname=save_freeze_dir,
                                                                                      executor=exe)


# print(fetch_targets)


def crop_image(img, target_size):
    width, height = img.size
    w_start = (width - target_size[2]) / 2
    h_start = (height - target_size[1]) / 2
    w_end = w_start + target_size[2]
    h_end = h_start + target_size[1]
    img = img.crop((w_start, h_start, w_end, h_end))
    return img


def resize_img(img, target_size):
    ret = img.resize((target_size[1], target_size[2]), Image.BILINEAR)
    return ret


def read_image(img_path):
    img = Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    plt.imshow(img)
    img = crop_image(img, target_size)
    img = np.array(img).astype('float32')
    img -= mean_rgb
    img = img.transpose((2, 0, 1))  # HWC to CHW
    img *= 0.007843
    img = img[np.newaxis, :]
    return img


def infer(image_path):
    tensor_img = read_image(image_path)
    label = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets)
    return np.argmax(label)


# def eval_all():
#     eval_file_path = os.path.join(data_dir, eval_file)
#     total_count = 0
#     right_count = 0
#     with codecs.open(eval_file_path, encoding='utf-8') as flist:
#         lines = [line.strip() for line in flist]
#         t1 = time.time()
#         for line in lines:
#             total_count += 1
#             parts = line.strip().split()
#             result = infer(parts[0])
#             # print("infer result:{0} answer:{1}".format(result, parts[1]))
#             if str(result) == parts[1]:
#                 right_count += 1
#         period = time.time() - t1
#         print("total eval count:{0} cost time:{1} predict accuracy:{2}".format(total_count, "%2.2f sec" % period,
#                                                                                right_count / total_count))



# def predict1():
#     eval_file_path = os.path.join(data_dir, eval_file)
#     total_count = 0
#     right_count = 0
#     with codecs.open(eval_file_path, encoding='utf-8') as flist:
#         lines = [line.strip() for line in flist]
#         t1 = time.time()
#         labels=["daisy(菊花)","dandelion(蒲公英)","rose(玫瑰)","sunflower(向日葵)","tulip(郁金香)"]
#         for line in lines:
#             total_count += 1
#             parts = line.strip().split()
#             result = infer(parts[0])
#             # print("infer result:{0} answer:{1}".format(result, parts[1]))
#             if str(result) == parts[1]:
#                 right_count += 1
#             plt.rcParams['font.sans-serif'] = ['SimHei']
#             plt.rcParams['axes.unicode_minus'] = False
#             plt.title('{}'.format(labels[int(result)]))
#             # print(labels[int(result)])
#         period = time.time() - t1
#         print("total eval count:{0} cost time:{1}".format(total_count, "%2.2f sec" % period,
#                                                                                right_count / total_count))


def predict(img_path1):
    t1 = time.time()
    labels=["daisy(菊花)","dandelion(蒲公英)","rose(玫瑰)","sunflower(向日葵)","tulip(郁金香)"]
    result = infer(img_path1)
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.title('{}'.format(labels[int(result)]))
    # print(labels[int(result)])
    period = time.time() - t1
    print("cost time:{}".format("%2.2f sec" % period))

if __name__ == '__main__':
    image_path = "data/evalImageSet/4414080766_5116e8084e.jpg"
    predict(image_path)

可以看出来,原作者的代码中是直接输出的在测试集中的准确率,注释掉的就是原来的输出代码段,我修改了他的predict函数,在输入一张图片之后,直接输出这张图片及其预测的标签,这样看起来更清晰,也可以直接使用自己拍摄的照片(我没试过)。

注意:

直接下载下来的数据集不能直接用于训练和预测,因为还没有进行数据预处理,需要处理后将其转换为标准格式,并将一些打不开的文件进行清洗。另外,在数据集中还有几个txt文件,这是记录图片名称和标签的,程序中有用到,需要放到正确的位置。
这样就可以运行了,输出效果如下:
PS:如果哪里出了问题,就照着问题改,很可能是文件路径出了问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值