Caffe常用python接口记录

1. python接口训练

caffe python下网络的训练

import caffe

caffe.set_device(int(0))
caffe.set_mode_gpu()  # GPU
caffe.set_mode_cpu()  # CPU

solver = caffe.SGDSolver('solver.prototxt') # 指定SGD solver,可换为其它求解算法
# 方式1
solver.solve()  # 会进行完整的梯度训练,直至在solver中规定的max_iter
# 方式2
n_iter = 1000
n_step = 10
for _ in range(n_iter):
	solver.step(n_step)  # 进行完整的n_step次计算(minibatch)(包括数据的前向传播,误差反向传播,以及网络权值的update)

个人比较喜欢使用上面的第二种方式,因为可以在中间进行测试,详细记录网络在训练过程中的变化。此外,从训练的中间过程中恢复使用如下的方式:

import caffe

caffe.set_device(int(0))
caffe.set_mode_gpu()

solver = caffe.SGDSolver('solver.prototxt')
solver.restore('_iter_5000.solverstate')
solver.solve()

2. python测试接口

import caffe
import cv2

deploy = 'deploy.prototxt'    #deploy文件
caffe_model = '_iter_9380.caffemodel'   #训练好的 caffemodel
net = caffe.Net(net_file, caffe_model, caffe.TEST)   #加载model和network

in_ = cv2.imread('test.jpg')
# other preprocessing
......

# shape for input (data blob is N x C x H x W), set data
net.blobs['data'].reshape(1, *in_.shape)
net.blobs['data'].data[...] = in_

# run net and take argmax for prediction
net.forward()

# extract output blob
......

3. python下网络数据抽取

net = caffe.Net(net_file, caffe_model, caffe.TEST)   #加载model和network

上面的这行代码就把所有的参数和数据都加载到一个net变量里面了,但是net是一个很复杂的object, 想直接显示出来看是不行的。其中:

  • net.params: 保存各层的参数值(w和b)
  • net.blobs: 保存各层的数据值,forward之后才会有

对于网络中的权值参数可以使用访问字典形式进行访问:

# 1.
[(k,v[0].data) for k,v in net.params.items()]
# 2.
conv1_w = net.params['conv1'].data[0]

查看各层的参数值,其中k表示层的名称,v[0].data就是各层的W值,而v[1].data是各层的b值。 这里需要注意的是:上面取出的参数在训练过程中是可读可写的,不管是GPU还是CPU环境下,这样就为某些用途提供了便利,比如剪裁。。。
除了查看参数,我们还可以查看数据,但是要注意的是,net里面刚开始是没有数据的,需要运行:

net.forward()

之后才会有数据。我们可以用下面的代码去访问:

# 1.
[(k,v.data) for k,v in net.blobs.items()]
# 2.
layer_data = net.blobs['conv1'].data

4. snapshot管理

caffe的python接口支持的中间模型文件保存方式有2种:

  • 方法一:
    solver.snapshot() 可以在训练过程中,手动进行snapshot,它会保存 .caffenodel与 .solverstate两个文件;常用于进行恢复训练过程;(保存的路径为solver.prototxt 文件里面定义的路径);

  • 方法二:
    net.save(), 它只会保存一下 .caffemodel文件,常用于进行测试时。 使用方法:如, net. save(‘my_path/my_weights.caffemodel’);

5. 网络finetune

  • 方法一:直接从现有的caffemodel里面进行拷贝
my_solver = caffe.get_solver(net_solver.prototxt)
my_solver.net.copy_from('pretraind.caffemodel')
  • 方法二:上述的方法在一些维度不匹配的时候会报错,但是预训练网络的其它一些参数还是可以使用的,这就需要使用下面的代码:
# source: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
def transplant(new_net, net, suffix=''):
    """
    Transfer weights by copying matching parameters, coercing parameters of
    incompatible shape, and dropping unmatched parameters.
    The coercion is useful to convert fully connected layers to their
    equivalent convolutional layers, since the weights are the same and only
    the shapes are different.  In particular, equivalent fully connected and
    convolution layers have shapes O x I and O x I x H x W respectively for O
    outputs channels, I input channels, H kernel height, and W kernel width.
    Both  `net` to `new_net` arguments must be instantiated `caffe.Net`s.
    """
    for p in net.params:
        p_new = p + suffix
        if p_new not in new_net.params:
            print 'dropping', p
            continue
        for i in range(len(net.params[p])):
            if i > (len(new_net.params[p_new]) - 1):
                print 'dropping', p, i
                break
            if net.params[p][i].data.shape != new_net.params[p_new][i].data.shape:
                print 'coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p_new][i].data.shape
            else:
                print 'copying', p, ' -> ', p_new, i
            new_net.params[p_new][i].data.flat = net.params[p][i].data.flat
            
vgg_net = caffe.Net('deploy_21.prototxt', weights, caffe.TRAIN) # VGG net
transplant(solver.net, vgg_net)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值