本文介绍如何使用 NVIDIA DIGITS 训练的模型对图片数据集进行预测分类。直接用 Caffe 训练的模型同样适用，这里主要以 DIGITS 训练的模型为例。
一、环境配置
1、digits环境安装
具体安装步骤不再赘述，可参考官方文档：DIGITS 安装指南。
2、caffe环境(可选)
由于 DIGITS 已封装了 Caffe，可以直接安装 NVIDIA 的 DIGITS；已有 Caffe 环境的读者可跳过本节。
也可以按照官方教程自行编译安装 Caffe（见官方安装说明）。
二、python实现
# -*- coding:utf-8 -*-
import numpy as np
import sys,os,caffe
import json
import shutil
# Running totals of images classified so far; model_classify updates these
# and prints their sum as a progress counter.
glassesNum = no_glassesNum = 0
# Loaded (net, transformer, labels) triples keyed by the model files they were
# built from, so the Caffe network is loaded once rather than once per image.
_MODEL_CACHE = {}


def _load_model():
    """Return a cached ``(net, transformer, labels)`` for the current model globals.

    Reads the module globals ``net_file``, ``caffe_model``, ``mean_file`` and
    ``labels_file`` (set by ``getFileName``) and builds the network and the
    preprocessing pipeline only the first time a given combination is seen.
    """
    cache_key = (net_file, caffe_model, mean_file, labels_file)
    if cache_key not in _MODEL_CACHE:
        # If pycaffe is not importable from the environment, add it first, e.g.
        #   sys.path.insert(0, caffe_root + 'python'); os.chdir(caffe_root)
        net = caffe.Net(net_file, caffe_model, caffe.TEST)
        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2, 0, 1))  # HWC image -> CHW blob
        # Collapse the (C, H, W) mean image saved by BpToNpy to a per-channel mean.
        transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
        transformer.set_raw_scale('data', 255)  # load_image yields [0, 1]
        transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR
        labels = np.loadtxt(labels_file, str, delimiter='\t')
        _MODEL_CACHE[cache_key] = (net, transformer, labels)
    return _MODEL_CACHE[cache_key]


def model_classify(image_path, threshold=95.00):
    """Classify one image and copy it into a folder named after the result.

    The image is copied to ``work_root/<top-1 label>/`` when the top-1
    softmax probability exceeds *threshold* percent, and to
    ``work_root/unknow/`` otherwise. The copy is named
    ``<top-1 label><top-1 percent>_<original basename>``. Confident results
    also bump the running counter and print a progress line containing the
    two most likely labels (percent, 2 decimals) as JSON.

    Args:
        image_path: path of the image file to classify.
        threshold: confidence cut-off in percent (default 95.00, as before).
    """
    # Declared before any use: Python 3 rejects a ``global`` after the name is used.
    global glassesNum, no_glassesNum
    net, transformer, labels = _load_model()
    img = caffe.io.load_image(image_path)
    net.blobs['data'].data[...] = transformer.preprocess('data', img)
    net.forward()
    probs = net.blobs['softmax'].data[0].flatten()
    ranked = probs.argsort()[::-1]  # class indices, most probable first

    # float() so round() yields a JSON-serialisable Python float on Py2 and Py3.
    key = str(labels[ranked[0]])
    value = round(float(probs[ranked[0]]) * 100, 2)
    data = {key: value}
    if len(ranked) > 1:
        data[str(labels[ranked[1]])] = round(float(probs[ranked[1]]) * 100, 2)
    jsonstr = json.dumps(data)

    confident = value > threshold
    target_dir = os.path.join(work_root, key if confident else 'unknow')
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    target_name = key + str(value) + '_' + os.path.basename(image_path)
    shutil.copy(image_path, os.path.join(target_dir, target_name))

    if confident:
        # The previous second branch compared the runner-up probability with
        # the 95% cut-off, which can never pass (the runner-up is at most 50%),
        # and its else clause copied every image into 'unknow' as well.
        no_glassesNum += 1
        # filesum is only defined by the __main__ block; fall back for other callers.
        total = globals().get('filesum', '?')
        print("=================" + str(no_glassesNum + glassesNum) + '/'
              + str(total) + ' ' + jsonstr
              + "========================================")
# Convert the Caffe binaryproto image-mean file to NumPy .npy format.
def BpToNpy(proto_path='mean.binaryproto', npy_path='mean.npy'):
    """Convert a Caffe ``.binaryproto`` mean image into a NumPy ``.npy`` file.

    Args:
        proto_path: binaryproto mean file to read (relative paths resolve
            against the current working directory).
        npy_path: destination of the saved mean array.
    """
    blob = caffe.proto.caffe_pb2.BlobProto()
    # Close the file deterministically instead of leaking the handle.
    with open(proto_path, 'rb') as proto_file:
        blob.ParseFromString(proto_file.read())
    # Shape is (mean_number, channel, height, width); a blob can hold several
    # means, so keep only the first one.
    array = np.array(caffe.io.blobproto_to_array(blob))
    np.save(npy_path, array[0])
# Locate the model files (network definition, weights, labels, mean).
def getFileName(path):
    """Find the Caffe model files in *path* and publish them as module globals.

    Sets ``net_file`` (``.prototxt``), ``caffe_model`` (``.caffemodel``),
    ``labels_file`` (``.txt``) and ``mean_file`` (``.npy``) to
    ``os.path.join(path, name)``. A global is left untouched when *path*
    contains no file with that extension.

    When several files share an extension, the choice is deterministic:

    * ``deploy.prototxt`` and ``labels.txt`` win when present, because a
      model folder may also hold training-only files such as
      ``train_val.prototxt`` or ``solver.prototxt``.
    * For ``.caffemodel`` files, the one with the largest last number in its
      name (e.g. ``snapshot_iter_<N>``) is used.
    * Otherwise the alphabetically last name is used.
    """
    global net_file, caffe_model, labels_file, mean_file
    import re

    by_ext = {}
    for filename in sorted(os.listdir(path)):
        by_ext.setdefault(os.path.splitext(filename)[1], []).append(filename)

    def pick(ext, preferred=None, key=None):
        """Return the full path chosen for *ext*, or None if there is none."""
        names = by_ext.get(ext)
        if not names:
            return None
        if preferred in names:
            chosen = preferred
        elif key is not None:
            chosen = max(names, key=key)
        else:
            chosen = names[-1]
        # Join with the given path (the old code ignored it and used work_root).
        return os.path.join(path, chosen)

    def snapshot_key(name):
        """Order weight files by the last number in their name, then by name."""
        numbers = re.findall(r'\d+', name)
        return (int(numbers[-1]) if numbers else -1, name)

    chosen = pick('.prototxt', preferred='deploy.prototxt')
    if chosen is not None:
        net_file = chosen
    chosen = pick('.caffemodel', key=snapshot_key)
    if chosen is not None:
        caffe_model = chosen
    chosen = pick('.txt', preferred='labels.txt')
    if chosen is not None:
        labels_file = chosen
    chosen = pick('.npy')
    if chosen is not None:
        mean_file = chosen
def _read_image_list(list_path):
    """Return the image paths listed one per line in *list_path*.

    Surrounding whitespace (including Windows CRLF line endings) is stripped
    and blank lines are skipped, so the first line is kept and no empty path
    is produced at end of file.
    """
    with open(list_path) as list_file:
        return [line.strip() for line in list_file if line.strip()]


def _download_image(url, dest_dir):
    """Download *url* into *dest_dir* and return the saved file path.

    Returns None when *url* is not a URL, cannot be fetched, or answers with
    a status other than 200.
    """
    try:
        from urllib.request import urlopen  # Python 3
    except ImportError:
        from urllib2 import urlopen  # Python 2
    try:
        response = urlopen(url)
    except (IOError, ValueError):
        # URLError/HTTPError derive from IOError (OSError on Python 3);
        # ValueError means the string is not a URL at all.
        return None
    try:
        if response.getcode() != 200:
            return None
        image_data = response.read()
    finally:
        response.close()
    # Use the last URL path component (query string dropped) as the file name.
    image_name = os.path.basename(url.split('?')[0]) or 'downloaded_image'
    new_imagepath = os.path.join(dest_dir, image_name)
    with open(new_imagepath, 'wb') as out_file:
        out_file.write(image_data)
    return new_imagepath


if __name__ == '__main__':
    work_root = os.getcwd() + '/'
    # Convert mean.binaryproto once, and only when no .npy mean file exists yet
    # (the old loop re-ran the conversion for every non-.npy file in the folder).
    if not any(name.endswith('.npy') for name in os.listdir(work_root)):
        BpToNpy()
    getFileName(work_root)
    # If caffe is not configured in the environment, set its path, e.g.
    # caffe_root = '/dataTwo/caffe-ssd'
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3
    image_path = read_line("Input your image path: ").strip()

    if os.path.isdir(image_path):
        # A directory: classify every regular file directly inside it.
        filenames = sorted(
            fn for fn in os.listdir(image_path)
            if os.path.isfile(os.path.join(image_path, fn)))
        filesum = str(len(filenames))
        for fn in filenames:
            model_classify(os.path.join(image_path, fn))
    elif os.path.isfile(image_path):
        if os.path.splitext(image_path)[1].lower() == '.txt':
            # A list file: one image path per line.
            image_list = _read_image_list(image_path)
            filesum = str(len(image_list))
            for line in image_list:
                print(line)
                model_classify(line)
        else:
            # A single image file.
            filesum = '1'
            model_classify(image_path)
    else:
        # Not a local path: treat it as a URL and download the image first.
        new_imagepath = _download_image(image_path, work_root)
        if new_imagepath is None:
            print("not found the folder!")
        else:
            filesum = '1'
            model_classify(new_imagepath)
代码细节这里不再过多介绍，完整代码可参考我的 GitHub 仓库：wulivicte/caffe_classify。