# 3. 网络权重初始化问题

# Original initialization: build the SGD solver from its prototxt definition.
solver = caffe.SGDSolver('solver.prototxt')
# Load the pretrained weights into the solver's training net.
# NOTE(review): presumably `weights` is a .caffemodel path and copy_from
# matches layers by name -- confirm against the pycaffe documentation.
solver.net.copy_from(weights)

solver = caffe.SGDSolver('solver.prototxt')
# The next 3 lines are the ones we need to add.
# Instantiate the donor network from its own prototxt and weights.
# NOTE(review): vgg_proto / vgg_weights are defined elsewhere -- presumably
# the VGG deploy prototxt and its .caffemodel; confirm.
vgg_net = caffe.Net(vgg_proto,vgg_weights,caffe.TRAIN)
# Copy the donor's parameters into the solver's net; transplant() coerces
# parameters whose shapes differ and drops the ones it cannot match.
surgery.transplant(solver.net,vgg_net)
# Drop the reference to the donor net now that its weights have been copied.
del vgg_net

def transplant(new_net, net, suffix=''):
    """
    Transfer weights by copying matching parameters, coercing parameters of
    incompatible shape, and dropping unmatched parameters.

    The coercion is useful to convert fully connected layers to their
    equivalent convolutional layers, since the weights are the same and only
    the shapes are different.  In particular, equivalent fully connected and
    convolution layers have shapes O x I and O x I x H x W respectively for O
    outputs channels, I input channels, H kernel height, and W kernel width.

    Both net and new_net arguments must be instantiated caffe.Nets.

    Args:
        new_net: destination net; its ``params`` blobs are overwritten
            in place.
        net: source net whose ``params`` are read.
        suffix: string appended to each source layer name to obtain the
            destination layer name (e.g. ``'-conv'``).

    Returns:
        None.  A one-line report is printed for every parameter that is
        dropped, coerced, or copied.
    """
    for name in net.params:
        target = name + suffix
        if target not in new_net.params:
            # No layer of that name in the destination: nothing to copy into.
            print('dropping {}'.format(name))
            continue
        src_blobs = net.params[name]
        dst_blobs = new_net.params[target]
        for i in range(len(src_blobs)):
            if i > (len(dst_blobs) - 1):
                # The destination layer has fewer blobs (e.g. no bias).
                print('dropping {} {}'.format(name, i))
                break
            src_shape = src_blobs[i].data.shape
            dst_shape = dst_blobs[i].data.shape
            if src_shape != dst_shape:
                print('coercing {} {} from {} to {}'.format(
                    name, i, src_shape, dst_shape))
            else:
                print('copying {}  ->  {} {}'.format(name, target, i))
            # Copy element-wise through the flat views so that blobs whose
            # shapes differ are still filled; this runs for both the
            # "coercing" and the "copying" case.
            dst_blobs[i].data.flat = src_blobs[i].data.flat

# 4. 反卷积层初始化问题

现象：loss 不下降——训练开始时的输出是什么样子，训练结束时仍然是什么样子。

# surgeries
# Select every parameterized layer whose name contains 'up'.
# NOTE(review): presumably these are the upsampling (deconvolution) layers
# of the net -- confirm against the layer names in the prototxt.
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
# Overwrite the weights of the selected layers with bilinear interpolation
# kernels (see interp's docstring).
surgery.interp(solver.net, interp_layers)

def interp(net, layers):
    """
    Set weights of each layer in layers to bilinear kernels for interpolation.

    Args:
        net: net whose ``params[layer][0].data`` weight arrays are modified
            in place; each must be 4-D with shape (m, k, h, w).
        layers: iterable of layer names to initialize.

    Raises:
        ValueError: if a layer's first two dimensions differ while the
            second is not 1, or if its filters are not square.
    """
    for layer in layers:
        blob = net.params[layer][0]
        m, k, h, w = blob.data.shape
        if m != k and k != 1:
            # A bare `raise` has no active exception to re-raise here, so
            # raise an explicit error that carries the diagnostic message.
            raise ValueError(
                'input + output channels need to be the same or |output| == 1')
        if h != w:
            raise ValueError('filters need to be square')
        filt = upsample_filt(h)
        # Paired index lists select the (i, i) channel pairs when m == k, and
        # broadcast against the single second-axis index when k == 1.
        blob.data[list(range(m)), list(range(k)), :, :] = filt

def upsample_filt(size):
    """
    Make a 2D bilinear kernel suitable for upsampling.

    Args:
        size: side length of the square kernel.

    Returns:
        A (size, size) float array equal to the outer product of a 1-D
        triangular (bilinear) profile with itself.
    """
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    # Divide by a float so that odd sizes (integer grid, integer center) do
    # not hit truncating integer division on interpreters without true
    # division.
    scale = float(factor)
    rows = 1 - abs(og[0] - center) / scale
    cols = 1 - abs(og[1] - center) / scale
    return rows * cols