caffe的Net类学习

【注意】这是临时的学习心得，内容比较杂乱，请勿因此耽误各位时间。

Caffe MATLAB 接口中 Net 类（Net.m 文件）的注释

classdef Net < handle
  % Net  Wrapper class of caffe::Net in MATLAB.
  %
  %   A Net can be built either from a model file (every call that is not a
  %   single struct argument is forwarded to caffe.get_net) or from a net
  %   handle struct produced by the caffe_ mex interface.
  %
  %   Index convention: indices reported by C++ are 0-based. The constructor
  %   converts the ones it exposes (input/output blob indices and the
  %   bottom/top id vectors) to MATLAB's 1-based indexing.

  properties (Access = private)
    hNet_self    % handle of the underlying net (passed to every caffe_ call)
    attributes   % struct returned by caffe_('net_get_attr', hNet_self)
    % attribute fields
    %     hLayer_layers
    %     hBlob_blobs
    %     input_blob_indices
    %     output_blob_indices
    %     layer_names
    %     blob_names
    %     bottom_id_vecs
    %     top_id_vecs
  end
  properties (SetAccess = private)
    layer_vec          % array of caffe.Layer objects, one per layer handle
    blob_vec           % array of caffe.Blob objects, one per blob handle
    inputs             % cell array of input blob names
    outputs            % cell array of output blob names
    name2layer_index   % containers.Map: layer name -> index into layer_vec
    name2blob_index    % containers.Map: blob name -> index into blob_vec
    layer_names        % cell array of layer names
    blob_names         % cell array of blob names
    bottom_id_vecs     % cell array of bottom blob index vectors (1-based)
    top_id_vecs        % cell array of top blob index vectors (1-based)
  end

  methods
    function self = Net(varargin)
      % Construct a Net.
      %   Net(model_file, ...)  Any arguments other than a single struct are
      %                         forwarded to caffe.get_net, and its result is
      %                         returned instead.
      %   Net(hNet)             Wrap an existing net handle struct.
      if ~(nargin == 1 && isstruct(varargin{1}))
        % construct a net from model_file
        self = caffe.get_net(varargin{:});
        return
      end
      % construct a net from handle
      hNet_net = varargin{1};
      CHECK(is_valid_handle(hNet_net), 'invalid Net handle');

      % setup self handle and attributes
      self.hNet_self = hNet_net;
      self.attributes = caffe_('net_get_attr', self.hNet_self);

      % wrap every layer handle in a caffe.Layer object
      self.layer_vec = caffe.Layer.empty();
      for n = 1:length(self.attributes.hLayer_layers)
        self.layer_vec(n) = caffe.Layer(self.attributes.hLayer_layers(n));
      end

      % wrap every blob handle in a caffe.Blob object
      self.blob_vec = caffe.Blob.empty();
      for n = 1:length(self.attributes.hBlob_blobs)
        self.blob_vec(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
      end

      % setup input and output blob names
      % note: add 1 to indices as matlab is 1-indexed while C++ is 0-indexed
      self.inputs = ...
        self.attributes.blob_names(self.attributes.input_blob_indices + 1);
      self.outputs = ...
        self.attributes.blob_names(self.attributes.output_blob_indices + 1);

      % create map objects to map from name to layers and blobs
      self.name2layer_index = containers.Map(self.attributes.layer_names, ...
        1:length(self.attributes.layer_names));
      self.name2blob_index = containers.Map(self.attributes.blob_names, ...
        1:length(self.attributes.blob_names));

      % expose layer_names and blob_names for public read access
      self.layer_names = self.attributes.layer_names;
      self.blob_names = self.attributes.blob_names;

      % expose bottom_id_vecs and top_id_vecs for public read access,
      % converted from 0-based (C++) to 1-based (MATLAB) indices
      self.attributes.bottom_id_vecs = cellfun(@(x) x+1, self.attributes.bottom_id_vecs, 'UniformOutput', false);
      self.bottom_id_vecs = self.attributes.bottom_id_vecs;
      self.attributes.top_id_vecs = cellfun(@(x) x+1, self.attributes.top_id_vecs, 'UniformOutput', false);
      self.top_id_vecs = self.attributes.top_id_vecs;
    end

    function set_phase(self, phase_name)
      % Set the net phase.
      %   phase_name: either 'train' or 'test'; anything else fails CHECK.
      CHECK(ischar(phase_name), 'phase_name must be a string');
      CHECK(strcmp(phase_name, 'train') || strcmp(phase_name, 'test'), ...
            sprintf('phase_name can only be %strain%s or %stest%s', ...
            char(39), char(39), char(39), char(39)));
      caffe_('net_set_phase', self.hNet_self, phase_name);
    end

    function share_weights_with(self, net)
      % Share trained layers between this net and another caffe.Net.
      %   net: another caffe.Net instance.
      % Calls caffe_('net_share_trained_layers_with', <this net>, <net>).
      % NOTE(review): presumably this net reuses the trained weights of `net`;
      % confirm the direction against the caffe_ mex implementation.
      % The net handle is stored in the private property hNet_self (there is
      % no hNet_net property), so that is what must be read on both nets.
      CHECK(isa(net, 'caffe.Net'), 'net must be a caffe.Net');
      CHECK(is_valid_handle(net.hNet_self), 'invalid Net handle');
      caffe_('net_share_trained_layers_with', self.hNet_self, net.hNet_self);
    end

    function layer = layers(self, layer_name)
      % Return the caffe.Layer whose name is layer_name.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      layer = self.layer_vec(self.name2layer_index(layer_name));
    end

    function blob = blobs(self, blob_name)
      % Return the caffe.Blob whose name is blob_name.
      CHECK(ischar(blob_name), 'blob_name must be a string');
      blob = self.blob_vec(self.name2blob_index(blob_name));
    end

    function blob = params(self, layer_name, blob_index)
      % Return parameter blob number blob_index of the layer named layer_name.
      %   blob_index: scalar index passed to caffe.Layer.params.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      blob = self.layer_vec(self.name2layer_index(layer_name)).params(blob_index);
    end

    function set_params_data(self, layer_name, blob_index, data)
      % Set the data of parameter blob number blob_index of the layer named
      % layer_name to `data`.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      self.layer_vec(self.name2layer_index(layer_name)).set_params_data(blob_index, data);
    end

    function forward_prefilled(self)
      % Run a forward pass on whatever data the input blobs already hold
      % (nothing is copied in first).
      caffe_('net_forward', self.hNet_self);
    end

    function backward_prefilled(self)
      % Run a backward pass on whatever diffs the output blobs already hold
      % (nothing is copied in first).
      caffe_('net_backward', self.hNet_self);
    end

    function set_input_data(self, input_data)
      % Copy input data into the input blobs without running the net.
      %   input_data: cell array with one entry per input blob, in the order
      %               of self.inputs.
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % copy data to input blobs
      for n = 1:length(self.inputs)
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end
    end

    function res = get_output(self)
      % Get the data of every output blob.
      %   res: struct array with fields blob_name and data, one element per
      %        output blob; an empty struct array when there are no outputs.
      res = struct('blob_name', {}, 'data', {});
      for n = 1:length(self.outputs)
        res(n).blob_name = self.outputs{n};
        res(n).data = self.blobs(self.outputs{n}).get_data();
      end
    end

    function res = forward(self, input_data)
      % Fill the input blobs, run a forward pass and return the outputs.
      %   input_data: cell array with one entry per input blob; empty entries
      %               leave the corresponding blob's current data untouched.
      %   res: column cell array with the data of each output blob, in the
      %        order of self.outputs.
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % copy data to input blobs
      for n = 1:length(self.inputs)
        if isempty(input_data{n})
          continue;
        end
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end
      self.forward_prefilled();
      % retrieve data from output blobs
      res = cell(length(self.outputs), 1);
      for n = 1:length(self.outputs)
        res{n} = self.blobs(self.outputs{n}).get_data();
      end
    end

    function res = backward(self, output_diff)
      % Fill the output blob diffs, run a backward pass and return the input
      % blob diffs.
      %   output_diff: cell array with one entry per output blob.
      %   res: column cell array with the diff of each input blob, in the
      %        order of self.inputs.
      CHECK(iscell(output_diff), 'output_diff must be a cell array');
      CHECK(length(output_diff) == length(self.outputs), ...
        'output diff cell length must match output blob number');
      % copy diff to output blobs
      for n = 1:length(self.outputs)
        self.blobs(self.outputs{n}).set_diff(output_diff{n});
      end
      self.backward_prefilled();
      % retrieve diff from input blobs
      res = cell(length(self.inputs), 1);
      for n = 1:length(self.inputs)
        res{n} = self.blobs(self.inputs{n}).get_diff();
      end
    end

    function copy_from(self, weights_file)
      % Load parameters from weights_file into this net.
      %   weights_file: path to an existing weights file (checked first).
      CHECK(ischar(weights_file), 'weights_file must be a string');
      CHECK_FILE_EXIST(weights_file);
      caffe_('net_copy_from', self.hNet_self, weights_file);
    end

    function reshape(self)
      % Ask the underlying net to reshape (caffe_('net_reshape')), e.g. after
      % the input blobs have been resized.
      caffe_('net_reshape', self.hNet_self);
    end

    function reshape_as_input(self, input_data)
      % Resize the input blobs to the sizes of input_data, then reshape the net.
      %   input_data: cell array with one entry per input blob; empty entries
      %               leave the corresponding blob's shape untouched.
      % Each size vector is padded with trailing 1s up to 4 dimensions, since
      % MATLAB's size() drops trailing singleton dimensions.
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % reshape input blobs
      for n = 1:length(self.inputs)
        if isempty(input_data{n})
          continue;
        end
        input_data_size = size(input_data{n});
        input_data_size_extended = [input_data_size, ...
          ones(1, max(0, 4 - length(input_data_size)))];
        self.blobs(self.inputs{n}).reshape(input_data_size_extended);
      end
      self.reshape();
    end

    function save(self, weights_file)
      % Save this net's parameters to weights_file.
      CHECK(ischar(weights_file), 'weights_file must be a string');
      caffe_('net_save', self.hNet_self, weights_file);
    end
  end
end


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值