利用caffe提取特征,主要有两种方法:
方法1:利用命令行的方式,可以参考:http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html,这种方式提取的特征是lmdb形式的。为了可视化特征,可以先将lmdb格式的数据转换为mat格式的数据,再利用matlab来可视化特征。本方法主要参考博客:http://blog.csdn.net/jiandanjinxin/article/details/50410290
1)在caffe 根目录下创建feat_helper_pb2.py 和lmdb2mat.py,这两个文件的作用就是将lmdb转换为mat,具体内容如下:
feat_helper_pb2.py :
# Generated by the protocol buffer compiler. DO NOT EDIT!
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
# File-level descriptor for the generated module: it names the source schema
# ('datum.proto') and its protobuf package ('feat_extract').
# `serialized_pb` carries the schema in serialized form; it is machine
# generated, so never edit the byte string by hand (regenerate with protoc).
DESCRIPTOR = descriptor.FileDescriptor(
name='datum.proto',
package='feat_extract',
serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')
# Descriptor of the `feat_extract.Datum` message: six fields numbered 1..6
# (channels, height, width, data, label, float_data).
# The numeric `type` / `cpp_type` / `label` codes follow protobuf's
# FieldDescriptor constants (5 = int32, 12 = bytes, 2 = float;
# label 1 = optional, 3 = repeated) -- confirm against google.protobuf.descriptor.
_DATUM = descriptor.Descriptor(
name='Datum',
full_name='feat_extract.Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
# Field 1: 'channels', integer scalar, default 0.
descriptor.FieldDescriptor(
name='channels', full_name='feat_extract.Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
# Field 2: 'height', integer scalar, default 0.
descriptor.FieldDescriptor(
name='height', full_name='feat_extract.Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
# Field 3: 'width', integer scalar, default 0.
descriptor.FieldDescriptor(
name='width', full_name='feat_extract.Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
# Field 4: 'data', byte string, default "".
descriptor.FieldDescriptor(
name='data', full_name='feat_extract.Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value="",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
# Field 5: 'label', integer scalar, default 0.
descriptor.FieldDescriptor(
name='label', full_name='feat_extract.Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
# Field 6: 'float_data', repeated float (default is an empty list).
# This is the field lmdb2mat.py reads to obtain one feature vector.
descriptor.FieldDescriptor(
name='float_data', full_name='feat_extract.Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
# presumably byte offsets of this message inside DESCRIPTOR's serialized_pb
# -- generated values, do not adjust manually.
serialized_start=29,
serialized_end=134,
)
# Register the message descriptor on the file descriptor under its short name.
DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
# Concrete message class for `feat_extract.Datum`.
# The class body is intentionally almost empty: the metaclass below builds the
# field accessors and (de)serialization methods from the `_DATUM` descriptor.
class Datum(message.Message):
    # Python 2 style metaclass hook used by this generation of protoc output.
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = _DATUM

    # @@protoc_insertion_point(class_scope:feat_extract.Datum)
# @@protoc_insertion_point(module_scope)
lmdb2mat.py:
import lmdb
import feat_helper_pb2
import numpy as np
import scipy.io as sio
import time
def main(argv):
    """Convert features stored in an LMDB into a MATLAB ``.mat`` file.

    The first ``BATCHNUM * BATCHSIZE`` records of the LMDB are each parsed as
    a ``feat_helper_pb2.Datum``; the ``float_data`` of every record becomes
    one row of a ``(BATCHNUM * BATCHSIZE, DIM)`` matrix that is written to the
    output file under the key ``'feats'``.

    Args:
        argv: Command-line style argument list
            ``[prog, IN, BATCHNUM, BATCHSIZE, DIM, OUT]`` where IN is the LMDB
            path, BATCHNUM and BATCHSIZE are integers whose product is the
            number of records to convert, DIM is the length of one feature
            vector and OUT is the ``.mat`` file to write.

    Raises:
        SystemExit: If fewer than five arguments follow the program name.
        IndexError: If the LMDB holds fewer records than
            ``BATCHNUM * BATCHSIZE``.
        ValueError: If a numeric argument is not an integer (from ``int``),
            or if a record's ``float_data`` cannot be assigned to a row of
            length DIM (raised by NumPy).
    """
    if len(argv) < 6:
        raise SystemExit(
            "usage: python lmdb2mat.py IN BATCHNUM BATCHSIZE DIM OUT")

    # Use the argv that was passed in rather than the global sys.argv, so the
    # function also works when it is imported and called from other code.
    lmdb_name = argv[1]
    print("%s" % lmdb_name)
    batch_num = int(argv[2])
    batch_size = int(argv[3])
    feat_dim = int(argv[4])
    out_path = argv[5]
    window_num = batch_num * batch_size

    start = time.time()
    ft = np.zeros((window_num, feat_dim))
    datum = feat_helper_pb2.Datum()

    # Read-only: this tool never writes to the database, and a mistyped path
    # should fail instead of silently creating a new, empty LMDB.
    db = lmdb.open(lmdb_name, readonly=True)
    try:
        with db.begin() as txn:
            cursor = txn.cursor()
            n_read = 0
            # The iterator is only used to advance the cursor; the record
            # payload is fetched explicitly with cursor.value(), so the code
            # does not depend on what the iterator itself yields.
            for _ in cursor.iternext_nodup():
                if n_read >= window_num:
                    break
                # Parse each record as it is read instead of first copying
                # the whole database into a Python list.
                datum.ParseFromString(cursor.value())
                ft[n_read, :] = datum.float_data
                n_read += 1
    finally:
        # Always release the LMDB environment, even if parsing fails.
        db.close()

    if n_read < window_num:
        raise IndexError(
            "LMDB %r holds only %d records, but BATCHNUM*BATCHSIZE = %d "
            "were requested" % (lmdb_name, n_read, window_num))

    print('time 1: %f' % (time.time() - start))
    sio.savemat(out_path, {'feats': ft})
    print('time 2: %f' % (time.time() - start))
    print('done!')
# Script entry point: forward the raw command-line arguments to main().
if __name__ == '__main__':
    import sys
    main(sys.argv)
2)执行 python lmdb2mat.py IN BATCHNUM BATCHSIZE DIM OUT,这样就将lmdb文件转换为mat文件,
其中,IN 表示输入的lmdb文件,BATCHNUM表示有多少个batch,BATCHSIZE表示batch的大小,DIM表示每个样本的feature长度(需要手动计算,例如fc7层为4096),OUT表示输出的mat文件。注意 BATCHNUM×BATCHSIZE 不能超过lmdb中实际保存的特征条数,否则脚本会因索引越界而报错。
3)对mat文件里的特征进行可视化
display_network.m:
function [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
% DISPLAY_NETWORK  Tile the columns of A as small square images in one plot.
%
% Each column of A is a filter. Every column is reshaped into a square
% sz-by-sz image (sz = sqrt(size(A,1)), so the number of rows of A must be a
% perfect square) and placed in one cell of the visualization panel; cells
% are separated by a one-pixel border.
%
% All parameters other than A are optional; pass [] to keep a default.
%   opt_normalize: whether to normalize each filter by its own maximum
%                  absolute value so that all tiles have similar contrast.
%                  Default value is true.
%   opt_graycolor: whether to use gray as the heat map. Default is true.
%   cols:          how many columns there are in the display. By default
%                  it is derived from the square root of the number of
%                  columns in A.
%   opt_colmajor:  when true, the panel is filled column by column instead
%                  of row by row. Default value is false.
%                  NOTE(review): the original help text described this as
%                  "each row of A is a filter", but A is never transposed
%                  here -- only the fill order changes.
%
% Returns:
%   h:     handle of the image object created by imagesc.
%   array: the assembled panel matrix that was displayed.

% Silence warnings while drawing, and restore the caller's warning state
% afterwards (also when an error occurs) instead of forcing every warning on.
origWarnState = warning('off', 'all');
restoreWarnings = onCleanup(@() warning(origWarnState)); %#ok<NASGU>

if ~exist('opt_normalize', 'var') || isempty(opt_normalize)
    opt_normalize = true;
end

if ~exist('opt_graycolor', 'var') || isempty(opt_graycolor)
    opt_graycolor = true;
end

if ~exist('opt_colmajor', 'var') || isempty(opt_colmajor)
    opt_colmajor = false;
end

% rescale: remove the global mean
A = A - mean(A(:));

if opt_graycolor, colormap(gray); end

% compute rows, cols
[L, M] = size(A);
sz = sqrt(L);
buf = 1;
% An empty `cols` means "use the default", like the other optional inputs;
% this allows display_network(A, [], [], [], true).
if ~exist('cols', 'var') || isempty(cols)
    if floor(sqrt(M))^2 ~= M
        n = ceil(sqrt(M));
        % look for a column count close to sqrt(M) that divides M evenly
        while mod(M, n) ~= 0 && n < 1.2*sqrt(M), n = n + 1; end
        m = ceil(M/n);
    else
        n = sqrt(M);
        m = n;
    end
else
    n = cols;
    m = ceil(M/n);
end

% Panel pre-filled with the border value (-1)
array = -ones(buf + m*(sz + buf), buf + n*(sz + buf));

if ~opt_graycolor
    array = 0.1 .* array;
end

% Loop-invariant scale used by the non-normalized row-wise branch
globalMax = max(abs(A(:)));

if ~opt_colmajor
    % fill the panel row by row
    k = 1;
    for i = 1:m
        for j = 1:n
            if k > M
                continue;
            end
            rowIdx = buf + (i-1)*(sz + buf) + (1:sz);
            colIdx = buf + (j-1)*(sz + buf) + (1:sz);
            if opt_normalize
                clim = max(abs(A(:, k)));
                array(rowIdx, colIdx) = reshape(A(:, k), sz, sz)' / clim;
            else
                array(rowIdx, colIdx) = reshape(A(:, k), sz, sz)' / globalMax;
            end
            k = k + 1;
        end
    end
else
    % fill the panel column by column
    k = 1;
    for j = 1:n
        for i = 1:m
            if k > M
                continue;
            end
            rowIdx = buf + (i-1)*(sz + buf) + (1:sz);
            colIdx = buf + (j-1)*(sz + buf) + (1:sz);
            if opt_normalize
                clim = max(abs(A(:, k)));
                array(rowIdx, colIdx) = reshape(A(:, k), sz, sz)' / clim;
            else
                % NOTE(review): unlike the row-wise branch, no global scaling
                % is applied here; kept as in the original -- confirm whether
                % this difference is intentional.
                array(rowIdx, colIdx) = reshape(A(:, k), sz, sz)';
            end
            k = k + 1;
        end
    end
end

if opt_graycolor
    h = imagesc(array);
else
    % 'EraseMode' was dropped from the call: the property no longer exists in
    % current MATLAB graphics; the fixed color limits [-1 1] are kept.
    h = imagesc(array, [-1 1]);
end
axis image off

drawnow;
进入caffe根目录下的 examples/_temp/ 目录,启动matlab:
# Change into examples/_temp/ (relative to the current directory) and start
# MATLAB there, so that files loaded by bare name are looked up in that folder.
cd ./examples/_temp/
matlab
在matlab中输入下面的命令:
% Visualize the extracted features of the first few samples, one figure each.
nsample = 2;            % number of samples (rows of feats) to display

% Number of outputs of the layer the features were taken from:
% num_output = 96;      % conv1
% num_output = 256;     % conv5
num_output = 4096;      % fc7

% Brings the feature matrix `feats` (one row per sample) into the workspace.
load('features_fc7.mat');

width = size(feats, 2);
nmap = width / num_output;   % values per output unit within one sample

for i = 1:nsample
    % Take row i and lay it out as nmap-by-num_output for display_network.
    feat = reshape(feats(i, :), [nmap num_output]);
    figure('name', sprintf('image #%d', i));
    display_network(feat);
end
方法2:利用python提取特征,并进行可视化,参考:http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb,对于某一层的特征,可以采用如下方式保存在硬盘中:
# Explicit imports so the snippet does not depend on module aliases defined
# by the surrounding session.
import numpy
import scipy.io

# Directory the feature files are written to.
out_dir = "./caffe/examples/_temp/"

# Activations of layer 'fc6' for entry 0 of the blob
# (presumably the first image of the batch -- confirm against the net input).
feat_fc6 = net.blobs['fc6'].data[0]
# View the vector as a 4096x1 column, then transpose it into a single row.
feat_fc6.shape = (4096,1)
row_feat_fc6 = numpy.transpose(feat_fc6)
# The paths must be string literals; the unquoted paths were a syntax error.
numpy.savetxt(out_dir + "features.txt", row_feat_fc6)
scipy.io.savemat(out_dir + "features.mat", {'feature': row_feat_fc6})