Python接口调用已训练好的 caffemodel 测试分类

训练好caffemodel后,需要测试模型分类的正确率,caffe 有 python接口,可以调用已训练好的caffemodel测试分类。
有以下几点需要注意:

1, 需要修改 net.prototxt 文件为 deploy.prototxt 文件,方法见我的另一个博客。

deploy= '/home/justin/cnn-human/code/deploy.prototxt' #结构文件  

2, 图像的均值文件为相应的均值文件,这部分也要修改。(如果训练的时候没有加入数据的均值文件,测试的时候也不要加入均值文件测试)

transformer.set_mean('data', np.load('/home/justin/cnn-human/data/train/mean.npy').mean(1).mean(1)) 

3, 下面的image_label.txt文件为不同图像和对应的label。

 labels = np.loadtxt("/home/justin/cnn-human/data/image_label.txt", str,delimiter='\t')  

本次我用到的为:

0 uwalk
1 unrun
2 boxst
3 boxwa
4 aimwa
5 crawl

4, 其他的部分对比做相应的修改。

# -*- coding: UTF-8 -*-  
import os  
import caffe  
import numpy as np  
#root='/home/justin/caffe/'#指定根目录  
deploy= '/home/justin/cnn-human/code/deploy.prototxt' #结构文件  
caffe_model= '/home/justin/cnn-human/code/snapshots/_iter_12500.caffemodel'
#已经训练好的model  

dir = '/home/justin/cnn-human/data/testImg' #保存测试图片的集合  
filelist=[]  
filenames=os.listdir(dir)  
for fn in filenames:  
        fullfilename = os.path.join(dir,fn)  
        filelist.append(fullfilename)  
#filelist.append(fn)  
def Test(img):  
#加载模型  
        net = caffe.Net(deploy,caffe_model,caffe.TEST)  
        #网络设置为测试模式

# 加载输入和配置预处理  
        transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape})  #不懂什么意思?
        transformer.set_mean('data', np.load('/home/justin/cnn-human/data/train/mean.npy').mean(1).mean(1))  
        #减去均值?
        transformer.set_transpose('data', (2,0,1))  
        #python读取的图片文件格式为H×W×K,需转化为K×H×W
        transformer.set_channel_swap('data', (2,1,0))  
        #交换通道,将图片由RGB变为BGR(如果是灰度图片,此处可以注释,因为没有RGB一说,不注释会报:Exception: Channel swap needs to have the same number of dimensions as the input channels.)
        transformer.set_raw_scale('data', 255.0)  
        #缩放到[0,255]之间

#注意可以调节预处理批次的大小  
#由于是处理一张图片,所以把原来的15张的批次改为1  
        net.blobs['data'].reshape(1,3,128,128)  
        #3通道*128*128

#加载图片到数据层  
        im = caffe.io.load_image(img)  
        #加载图片(caffe.io.load_image(img,color=False),如果是灰度图片,此处第二个参数color=False一定要补上,不然默认加载成3通道图片,会报错,大致意思就是我们net里定义的是1通道的,与实际不符。ValueError: could not broadcast input array from shape (3,28,28) into shape (64,1,28,28))
        net.blobs['data'].data[...] = transformer.preprocess('data', im) 
        #执行上面设置的图片预处理操作,并将图片载入到blob中 

#前向计算  

        out = net.forward()  

# 其他可能的形式 : # out = net.forward_all(data=np.asarray([transformer.preprocess('data', im)]))  

#预测分类  
        print out['prob']
        print out['prob'].argmax()

# print("Predicted class is #{}.".format(out['prob'].softmax()))      

    #打印预测标签  
        labels = np.loadtxt("/home/justin/cnn-human/data/image_label.txt", str,delimiter='\t')  
        top_k = net.blobs['prob'].data[0].flatten().argsort()[-1]  
        #输出概率最大的类别的下标
        print 'the class is:',labels[top_k]  
        f=file("/home/justin/cnn-human/data/label.txt","a")  
        f.writelines(img+' '+labels[top_k]+'\n')  
#循环遍历文件夹root+'examples/images/'下的所有图片  
for i in range(0,len(filelist)):  
    img=filelist[i]  
    Test(img)  

1,net.blobs[‘data’].data.shape的解释
2,numpy.loadtxt的解释

可以参考:
http://blog.csdn.net/baterforyou/article/details/71430284
https://www.cnblogs.com/denny402/p/5111018.html
http://blog.csdn.net/u010925447/article/details/75805474
http://blog.csdn.net/yxq5997/article/details/53780394?utm_source=itdadao&utm_medium=referral

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值