caffe输出分类训练时验证集识别错误的样本

本文以mnist以及lenet为例

1.将测试错误样本打印出来

当运行测试时,最后的输出层为AccuracyLayer层。AccuracyLayer对前一层全连接层ip2的10个神经元输出结果进行排序,然后将最大值所对应的神经元序号与标签label进行比较,相等则判定预测正确;否则判定预测错误。所以,首先对accuracy_layer函数进行功能添加,打开src/caffe/layers/accuracy_layer.cpp文件,添加如下代码段:

void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  ...
  // check if true label is in top k predictions
  for (int k = 0; k < top_k_; k++) {
    if (bottom_data_vector[k].second == label_value) {
      // 预测正确
      ...
    }
    else
    {
      // 预测错误
      // index为batch中的图片序号(0~99),label为标签值,output为预测值
      LOG(INFO) << "index:" << i << " label:" << label_value << " output:" << bottom_data_vector[k].second;
    }
  }
}

这样我们就知道在一个batch中哪些图片被预测错误,以及它的标签值和预测值。测试样本总共有10000个,分为100个batch,每个batch大小为100个,所以我们还需要输出每个batch的序号。打开src/caffe/solver.cpp文件,跳转到Slover::Test()函数中,添加如下语句:

void Solver<Dtype>::Test(const int test_net_id) {
  ...
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    // 输出batch序号
    LOG(INFO) << "batch:" << i;
  }
}

做完上述改变之后发现运行训练程序时,对caffe进行make all,让它对修改过的层重新编译。 

2.将日志输出至文件

编译完成后,设置训练输出日志文件

$./examples/mnist/train_lenet.sh 2>&1 | tee lenet.log

见《深度学习:21天实战caffe》第295页 

3.用Matlab将错误样本可视化

下面我们来写段Matlab代码,用来读取上面的日志文件,以及将MNIST数据库可视化。

clear;clc;close all;

fid = fopen('caffe.exe.txt');   % 替换为日志文件名
tline = fgetl(fid);

C = [];     % 定义空矩阵用来存放结果

while ischar(tline)
    if ~isempty(strfind(tline, 'batch:'))  % 查找字符串
        indexline = fgetl(fid);
        if ~isempty(strfind(indexline, 'batch:'))
            tline = indexline;
        elseif isempty(strfind(indexline, 'index:'))
            tline = indexline;
        else
            % 在tline中解析batch
            idx1 = strfind(tline, 'batch:');
            batch = str2num(tline(idx1 + 6 : length(tline)));
            % 在indexline中解析index,label,output
            idx2 = strfind(indexline, 'index:');
            idx3 = strfind(indexline, 'label:');
            idx4 = strfind(indexline, 'output:');
            index = str2num(indexline(idx2 + 6 : idx3 - 2));
            label = str2num(indexline(idx3 + 6 : idx4 - 2));
            output = str2num(indexline(idx4 + 7 : length(indexline)));
            % 添加到数组中
            C = [C; batch, index, label, output];
        end
    else
        tline = fgetl(fid);
    end
end

fclose(fid);

% 可视化部分
image_file_name = 't10k-images.idx3-ubyte';
fid = fopen(image_file_name);
images_data = fread(fid, 'uint8');
fclose(fid);

images_data = images_data(17:end);
image_buffer = zeros(28, 28);

for k = 1:1:size(C,1)
    figure(size(C,1));
    index = C(k,1) * 100 + C(k,2);
    image_buffer = reshape(images_data((index) * 28 * 28 + 1 : (index + 1) * 28 * 28), 28, 28);
    subplot(10, 10, k);
    imshow(uint8(image_buffer)');    % 转置
    title(sprintf('%d->%d', C(k,3), C(k,4)));   % label -> output
end

4.可视化结果

结果如下所示,其中有些图片网络未能正确识别,还有些对于人眼来说都是模棱两可的,有点太难为机器了。。。

MNIST


  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值