FCN中的transplant

FCN中的surgery.transplant函数用于拷贝learnable参数,其直接目的是:将VGG分类模型中的一些全连接层的参数正确地拷贝到相应的目标全连接层中。代码如下:

def transplant(new_net, net, suffix=''):
    """
    Transfer weights by copying matching parameters, coercing parameters of
    incompatible shape, and dropping unmatched parameters.

    The coercion is useful to convert fully connected layers to their
    equivalent convolutional layers, since the weights are the same and only
    the shapes are different.  In particular, equivalent fully connected and
    convolution layers have shapes O x I and O x I x H x W respectively for O
    outputs channels, I input channels, H kernel height, and W kernel width.

    Both  `net` to `new_net` arguments must be instantiated `caffe.Net`s.
    """
    for p in net.params:
        p_new = p + suffix
        if p_new not in new_net.params:
            print 'dropping', p
            continue
        for i in range(len(net.params[p])):
            if i > (len(new_net.params[p_new]) - 1):
                print 'dropping', p, i
                break
            if net.params[p][i].data.shape != new_net.params[p_new][i].data.shape:
                print 'coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p_new][i].data.shape
            else:
                print 'copying', p, ' -> ', p_new, i
            new_net.params[p_new][i].data.flat = net.params[p][i].data.flat

ndarray.flat返回flatiter对象,即
这里写图片描述
surgery.transplant的调用方式在solve.py中:surgery.transplant(solver.net,vgg_net)

import sys    
sys.path.append('/home/my/caffe-master/caffe-master/python')  
import caffe  
import surgery, score  

import numpy as np  
import os  
import sys  

try:  
    import setproctitle  
    setproctitle.setproctitle(os.path.basename(os.getcwd()))  
except:  
    pass  

vgg_weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'  
vgg_proto = '../ilsvrc-nets/VGG_ILSVRC_16_layers_deploy.prototxt'  
weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'  

# init  
caffe.set_mode_gpu()  
# caffe.set_device(int(sys.argv[0]))  
caffe.set_device(7)  

#solver = caffe.SGDSolver('solver.prototxt')  
#solver.net.copy_from(weights)  
solver = caffe.SGDSolver('solver.prototxt')  
vgg_net=caffe.Net(vgg_proto,vgg_weights,caffe.TRAIN)  
surgery.transplant(solver.net,vgg_net)  
del vgg_net  

# surgeries  
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]  
surgery.interp(solver.net, interp_layers)  

# scoring  
test = np.loadtxt('../data/sift-flow/test.txt', dtype=str)  

for _ in range(50):  
    solver.step(2000)  
    # N.B. metrics on the semantic labels are off b.c. of missing classes;  
    # score manually from the histogram instead for proper evaluation  
    score.seg_tests(solver, False, test, layer='score_sem', gt='sem')  
    score.seg_tests(solver, False, test, layer='score_geo', gt='geo') 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值