FCN中的surgery.transplant函数用于拷贝可学习(learnable)参数,其直接目的是:将VGG分类模型中一些全连接层的参数正确地拷贝到FCN中与之等价的卷积层里(两者的权值相同,只是形状不同)。代码如下:
def transplant(new_net, net, suffix=''):
    """
    Transfer weights by copying matching parameters, coercing parameters of
    incompatible shape, and dropping unmatched parameters.

    The coercion is useful to convert fully connected layers to their
    equivalent convolutional layers, since the weights are the same and only
    the shapes are different. In particular, equivalent fully connected and
    convolution layers have shapes O x I and O x I x H x W respectively for O
    outputs channels, I input channels, H kernel height, and W kernel width.

    Both `net` and `new_net` arguments must be instantiated `caffe.Net`s.

    Args:
        new_net: destination net; its parameter blobs are overwritten in place.
        net: source net whose parameters are read.
        suffix: string appended to each source layer name to obtain the
            destination layer name.

    Returns:
        None. Progress ('dropping' / 'coercing' / 'copying') is printed to
        stdout, one line per decision.

    Note:
        Values are copied through the flat (1-D) view of each blob, so a
        shape mismatch is tolerated. No check is made that the source and
        destination hold the same number of elements.
    """
    for name in net.params:
        new_name = name + suffix
        if new_name not in new_net.params:
            # The destination has no layer of this name: nothing to copy into.
            print('dropping {}'.format(name))
            continue

        src_blobs = net.params[name]
        dst_blobs = new_net.params[new_name]
        dst_count = len(dst_blobs)

        for i, src_blob in enumerate(src_blobs):
            if i >= dst_count:
                # The destination layer has fewer blobs (e.g. no bias term);
                # the remaining source blobs are skipped.
                print('dropping {} {}'.format(name, i))
                break

            dst_blob = dst_blobs[i]
            if src_blob.data.shape != dst_blob.data.shape:
                print('coercing {} {} from {} to {}'.format(
                    name, i, src_blob.data.shape, dst_blob.data.shape))
            else:
                print('copying {}  ->  {} {}'.format(name, new_name, i))

            # Assigning through `.flat` writes element by element, ignoring
            # the shapes of both arrays; this is what performs the coercion.
            dst_blob.data.flat = src_blob.data.flat
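ndarray.flat返回flatiter对象,即数组的一维迭代器(按行优先顺序遍历所有元素)。对flat赋值时会逐元素写入而不考虑数组的形状,因此只要元素总数相同,形状不同的参数(例如全连接层的O x I与卷积层的O x I x H x W)也可以直接拷贝。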
surgery.transplant在solve.py中的调用方式为:surgery.transplant(solver.net, vgg_net)。solve.py的完整代码如下:
import sys
sys.path.append('/home/my/caffe-master/caffe-master/python')
import caffe
import surgery, score
import numpy as np
import os
import sys
# Best effort: name the process after the current working directory.
# setproctitle is an optional dependency, so any failure here is ignored;
# `except Exception` (rather than a bare `except`) still lets
# KeyboardInterrupt and SystemExit propagate.
try:
    import setproctitle
    setproctitle.setproctitle(os.path.basename(os.getcwd()))
except Exception:
    pass

vgg_weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'
vgg_proto = '../ilsvrc-nets/VGG_ILSVRC_16_layers_deploy.prototxt'
# Not used below; kept so that the alternative initialisation
# `solver.net.copy_from(weights)` can be restored if desired.
weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'

# GPU to train on. Hard-coded; change it to match the local machine.
GPU_DEVICE_ID = 7

# init
caffe.set_mode_gpu()
caffe.set_device(GPU_DEVICE_ID)

# Initialise the solver's net from the VGG net with surgery.transplant,
# then release the VGG net since it is no longer needed.
solver = caffe.SGDSolver('solver.prototxt')
vgg_net = caffe.Net(vgg_proto, vgg_weights, caffe.TRAIN)
surgery.transplant(solver.net, vgg_net)
del vgg_net

# surgeries: apply surgery.interp to every parameterised layer whose name
# contains 'up'.
# NOTE(review): surgery.interp is defined elsewhere; presumably it sets up
# interpolation weights for the upsampling layers -- confirm in surgery.py.
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
surgery.interp(solver.net, interp_layers)

# scoring: the test list is read as an array of strings.
test = np.loadtxt('../data/sift-flow/test.txt', dtype=str)

# 50 rounds of 2000 solver steps, scoring both outputs after every round.
for _ in range(50):
    solver.step(2000)
    # N.B. metrics on the semantic labels are off b.c. of missing classes;
    # score manually from the histogram instead for proper evaluation
    score.seg_tests(solver, False, test, layer='score_sem', gt='sem')
    score.seg_tests(solver, False, test, layer='score_geo', gt='geo')