caffe分类器test通用python程序

       caffe分类器test,在pycaffe中提供有示例程序,并且它自己也封装了一个类,但感觉不是太友好,在我在原始基础上做了一些改动,方便以后测试用,在此整理如下:

       首先展示一下caffe官方提供的分类器test脚本大致如下:

caffe_root='/data/caffe/'
sys.path.insert(0, caffe_root + 'python')
os.chdir(caffe_root)
import caffe


def imagenetclassifyexample():
    net_file=caffe_root + 'models/bvlc_reference_caffenet/resnet18_deploy.prototxt'
    caffe_model=caffe_root + 'models/bvlc_reference_caffenet/resnet18.caffemodel'
    mean_file=caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'

    net = caffe.Net(net_file,caffe_model,caffe.TEST)
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2,0,1))
    meanvals=np.load(mean_file).mean(1).mean(1)
    transformer.set_mean('data', meanvals)
    transformer.set_raw_scale('data', 255)
    transformer.set_channel_swap('data', (2,1,0))

    im=caffe.io.load_image(caffe_root+'examples/images/cat.jpg')
    net.blobs['data'].data[...] = transformer.preprocess('data',im)
    out = net.forward()


    imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'
    labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')
    probs=net.blobs['prob'].data[0].flatten()
    probsortids = net.blobs['prob'].data[0].flatten().argsort()
    top_k=probsortids[-1:-6:-1]
    for i in np.arange(top_k.size):
        print(top_k[i],probs[top_k[i]],labels[top_k[i]])

上述代码具有普适性,但不方便批量测试和其他地方调用,为此分装了模块化如下:

#https://www.jianshu.com/p/5155fe9d109b
#https://www.zhihu.com/question/51621434

import numpy as np
import sys,os,json
caffe_root='/data/caffe/'
sys.path.insert(0, caffe_root + 'python')
os.chdir(caffe_root)
import caffe

class Classify:
    def __init__(self,net_file,caffe_model,meanvals,labels):
        net = caffe.Net(net_file, caffe_model, caffe.TEST)
        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2, 0, 1))## 将 高x宽x通道,转化为,通道x高x宽的格式
        if meanvals is not None:
            transformer.set_mean('data', meanvals)
            transformer.set_raw_scale('data', 255)#像素数值恢复[0-255] ,rescale from [0,1] to [0,255]
        transformer.set_channel_swap('data', (2, 1, 0))# swap channels from RGB to BGR
        self.net=net
        self.transformer=transformer
        self.labels=labels


    def predict(self,impath,top_k_num=1):
        im = caffe.io.load_image(impath)
        self.net.blobs['data'].data[...] = self.transformer.preprocess('data', im)
        out = self.net.forward()

        probs = self.net.blobs['prob'].data[0].flatten()
        probsortids = self.net.blobs['prob'].data[0].flatten().argsort()
        top_k = probsortids[-1:-(len(self.labels)):-1]

        return self.labels[top_k[0:top_k_num]], probs[top_k[0:top_k_num]]


    def get_classify_acc(self, immaindir):
        classify_acc = {}
        for label in labels:
            imdir = os.path.join(immaindir, label)
            if not os.path.isdir(imdir):
                continue
            imnames = os.listdir(imdir)
            imnum = len(imnames)
            i, k = 0, 0
            for imname in imnames:
                impath = os.path.join(imdir, imname)
                predlabel, prob = self.predict(impath)
                i += 1
                print(predlabel, prob, imname)
                if label == predlabel:
                    k += 1
            print(imnum, k, float(k) / imnum)
            classify_acc[label] = round(float(k) / imnum, 3)

        return classify_acc

现在整理及模块化封装后,感觉是不是整洁多了,也方便其他地方调用

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值