1.匹配问题(mnist为例)
code:https://github.com/ywpkwon/siamese_tf_mnist
网络结构:
训练时:
通过反复迭代最小化损失函数,训练模型
测试时:
用训练好的模型测试单张图片,得到输出y是二维向量。然后以该输出作为坐标,在该坐标展示该图片,例如输出[2.4565, -0.46545],该图片对应的标签是3。最后得到展示的图片:
可以看出类内距离小,类间距离远。
2.目标跟踪问题(DaSiamRPN为例)
code:https://github.com/foolwood/DaSiamRPN
该代码没有给出训练过程,测试过程如下。
程序中具体实现过程
1.第一帧 (360,480,3),然后抠出来(127,127,3),抠出的图目标作为模板图像,然后经过网络输出,得到输出(1,256,6,6),然后再卷积得到对应的两个输出self.r1_kernel(20,256,4,4)和self.cls1_kernel(10,256,4,4)
2.输入第i(i>1)帧图像,从(360,480,3)中抠出(271,271,3)的图,然后把抠出的(271,271,3)的图输入到网络中,得到输出(1,256,24,24),然后再卷积得到对应的两个输出self.conv_r2(1,256,22,22)和self.conv_cls2(1,256,22,22)。
3.用self.r1_kernel(20,256,4,4)去卷积self.conv_r2得到输出delta(1,20,19,19),用和self.cls1_kernel(10,256,4,4)去卷积self.conv_cls2得到score(1,10,19,19)
4.把delta(1,20,19,19)变换成(4,1805),把score(1,10,19,19)变换成(2,1805),再经过softmax得到(1805,)。根据score通过np.argmax找到分值最大的位置,进而得到预测的类别的bbox。
5.循环输入图片。重复步骤2-4。