给出代码地址:https://github.com/kastnerkyle/deform-conv,keras版本的。
可以直接看目录scripts下的scaled_mnist.py,网络模型由函数get_deform_cnn()加载:
# ---
# Deformable CNN
inputs, outputs = get_deform_cnn(trainable=False)
model = Model(inputs=inputs, outputs=outputs)
get_deform_cnn()定义在目录deform_conv下的cnn.py中,整体就是一个普通的cnn网络,只不过卷积前加了ConvOfffset2D:
def get_deform_cnn(trainable):
inputs = l = Input((28, 28, 1), name='input')
# conv11
l = Conv2D(32, (3, 3), padding='same', name='conv11', trainable=trainable)(l)
l = Activation('relu', name='conv11_relu')(l)
l = BatchNormalization(name='conv11_bn')(l)
# conv12
l_offset = ConvOffset2D(32, name='conv12_offset')(l)
l = Conv2