Caffe:使用 classify.py 批量对图片分类

一般使用 Caffe 训练完网络后,会用 test.bin 来测试一下网络的精度,然后还能用 classification.bin 来用网络对图片进行单张的分类,但是一张一张的分,效率很低,所以我改写了 classify.py 文件,使其读取 test.txt 文件批量分类,输出具体哪一张图片分错了。

代码如下:

# copyright (c) strongnine

import caffe
import sys
import os
import numpy as np
 
caffe_root = '/path/to/your/caffe/' # 指定 caffe 的路径
sys.path.insert(0,caffe_root+'python')
 
caffe.set_mode_gpu()
 
deploy = caffe_root+'models/bvlc_alexnet/deploy.prototxt' ##
caffe_model = caffe_root+'model/outputs/caffe_alexnet_train_iter_450000.caffemodel' ## 

labels_name = caffe_root+'data/alexnet/synset_words.txt'
labels = np.loadtxt(labels_name, str, delimiter='\t')
for i in range(len(labels)):
    exec(labels[i] + "=0")
right = 0
false = 0
mean_file = caffe_root+'data/alexnet/train_mean.npy' # 由 imagenet_mean.binaryproto 转换来
net = caffe.Net(deploy, caffe_model, caffe.TEST)

transformer=caffe.io.Transformer({'data':net.blobs['data'].data.shape})
transformer.set_transpose('data',(2,0,1))
transformer.set_mean('data',np.load(mean_file).mean(1).mean(1))
transformer.set_raw_scale('data',255)
transformer.set_channel_swap('data',(2,1,0))


test_file = open(caffe_root+'data/alexnet/test.txt', 'r')
test_data = test_file.readlines()

log = open(caffe_root+'data/alexnet/log/classify_log.log', 'w')
image_road = '/your/image/path/'

for line in test_data:
    split = line.split(' ')
    image = caffe.io.load_image(image_road + split[0])
    net.blobs['data'].data[...]=transformer.preprocess('data',image)
 
    out = net.forward()
 
    prob = net.blobs['prob'].data[0].flatten()
    top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
    
    log.write(split[0] + ' ' + split[1][0] + ' ' + str(top_k[0]))

    if str(top_k[0]) == split[1][0]:
        right += 1
        log.write(' right\n')
    else:
        false += 1
        log.write(' false\n')

print(right)
print(false)
print(right/float(right + false))

运行完成后会输出分类正确的图片数量,和分类错误的图片数量,以及所有的正确率。

生成完查看 log 文件:

...
cat_01.jpg 0 0 right
cat_02.jpg 0 0 right
person_04.jpg 1 1 right
person_05.jpg 1 0 false
...

第一个数字为标签类别,第二个数字为分类类别。

©️2020 CSDN 皮肤主题: 像素格子 设计师: CSDN官方博客 返回首页
实付0元
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值