本文使用了百度PaddlePaddle模型库中的模型,数据库和代码,这里只是讲解如何利用该平台上开源项目进行改进。
一、准备
模型链接
点击链接会进入模型界面,如图:
可以选择启动环境直接在线运行,也可以将文档中的代码复制下来进行简单修改后,在自己电脑上运行,当然,还需要下载数据集。如果不想自己训练的话,作者已经把训练好的模型放到了环境中,可以直接下载。
如上图,最左侧的freeze-model是作者训练好的模型,可以直接下载,使用的时候直接导入。data为数据集,里面包含了训练接和测试集。
二、离线运行
将文档中的代码按自己所需进行复制,放到自己的编译器中,我用的是Pycharm,Python3.8。记得先安装好Paddle,安装教程在官网有:安装
一切都准备好,下面就是配置环境和改代码了:
在工程中新建一个文件夹,专门用来做这个项目,这样能看上去更简洁明了。文件夹中再分两个子文件夹,一个用来放数据集,一个用来放模型,代码直接放在主文件夹中,就像这样:
我这里已经把代码修改好了,数据集和模型的路径什么的就现用现改,哪里有问题就再改,反正这个程序中用了很多名称来代指文件的路径,在复制的时候要把前面的字典定义也要放进去,代码如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import codecs
import paddle.fluid as fluid
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt
target_size = [3, 224, 224]
mean_rgb = [127.5, 127.5, 127.5]
data_dir = "data"
eval_file = "eval.txt"
use_gpu = True
place = fluid.CPUPlace() if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
save_freeze_dir = "freeze-model"
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname=save_freeze_dir,
executor=exe)
# print(fetch_targets)
def crop_image(img, target_size):
width, height = img.size
w_start = (width - target_size[2]) / 2
h_start = (height - target_size[1]) / 2
w_end = w_start + target_size[2]
h_end = h_start + target_size[1]
img = img.crop((w_start, h_start, w_end, h_end))
return img
def resize_img(img, target_size):
ret = img.resize((target_size[1], target_size[2]), Image.BILINEAR)
return ret
def read_image(img_path):
img = Image.open(img_path)
if img.mode != 'RGB':
img = img.convert('RGB')
plt.imshow(img)
img = crop_image(img, target_size)
img = np.array(img).astype('float32')
img -= mean_rgb
img = img.transpose((2, 0, 1)) # HWC to CHW
img *= 0.007843
img = img[np.newaxis, :]
return img
def infer(image_path):
tensor_img = read_image(image_path)
label = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets)
return np.argmax(label)
# def eval_all():
# eval_file_path = os.path.join(data_dir, eval_file)
# total_count = 0
# right_count = 0
# with codecs.open(eval_file_path, encoding='utf-8') as flist:
# lines = [line.strip() for line in flist]
# t1 = time.time()
# for line in lines:
# total_count += 1
# parts = line.strip().split()
# result = infer(parts[0])
# # print("infer result:{0} answer:{1}".format(result, parts[1]))
# if str(result) == parts[1]:
# right_count += 1
# period = time.time() - t1
# print("total eval count:{0} cost time:{1} predict accuracy:{2}".format(total_count, "%2.2f sec" % period,
# right_count / total_count))
# def predict1():
# eval_file_path = os.path.join(data_dir, eval_file)
# total_count = 0
# right_count = 0
# with codecs.open(eval_file_path, encoding='utf-8') as flist:
# lines = [line.strip() for line in flist]
# t1 = time.time()
# labels=["daisy(菊花)","dandelion(蒲公英)","rose(玫瑰)","sunflower(向日葵)","tulip(郁金香)"]
# for line in lines:
# total_count += 1
# parts = line.strip().split()
# result = infer(parts[0])
# # print("infer result:{0} answer:{1}".format(result, parts[1]))
# if str(result) == parts[1]:
# right_count += 1
# plt.rcParams['font.sans-serif'] = ['SimHei']
# plt.rcParams['axes.unicode_minus'] = False
# plt.title('{}'.format(labels[int(result)]))
# # print(labels[int(result)])
# period = time.time() - t1
# print("total eval count:{0} cost time:{1}".format(total_count, "%2.2f sec" % period,
# right_count / total_count))
def predict(img_path1):
t1 = time.time()
labels=["daisy(菊花)","dandelion(蒲公英)","rose(玫瑰)","sunflower(向日葵)","tulip(郁金香)"]
result = infer(img_path1)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.title('{}'.format(labels[int(result)]))
# print(labels[int(result)])
period = time.time() - t1
print("cost time:{}".format("%2.2f sec" % period))
if __name__ == '__main__':
image_path = "data/evalImageSet/4414080766_5116e8084e.jpg"
predict(image_path)
可以看出来,原作者的代码中是直接输出的在测试集中的准确率,注释掉的就是原来的输出代码段,我修改了他的predict函数,在输入一张图片之后,直接输出这张图片及其预测的标签,这样看起来更清晰,也可以直接使用自己拍摄的照片(我没试过)。
注意:
直接下载下来的数据集不能直接用于训练和预测,因为还没有进行数据预处理,需要处理后将其转换为标准格式,并将一些打不开的文件进行清洗。另外,在数据集中还有几个txt文件,这是记录图片名称和标签的,程序中有用到,需要放到正确的位置。
这样就可以运行了,输出效果如下:
PS:如果哪里出了问题,就照着问题改,很可能是文件路径出了问题。