Siamese网络
Few Shot Learning
首先看一个例子:
上图有一个Support set,其中的数据很少,不足以支持神经网络的训练。现在我们有一个Query,其类型是 Support set 中的一类。那么我们怎么通过 Support set 找到 Query 的所属分类。当然了我们人一眼就能分辨。但是怎么让机器去识别呢?
Few Shot Learning 可以解决上面这种情况,其通过很少的样本做分类或回归。
看两个 Few Shot Learning 中的几个概念:
k
k
k-way
n
n
n-shot Support Set,其中
k
k
k 是 Support Set 中的类别个数,
n
n
n 是每个类别中的样本数,上图是一个
3
3
3-way
2
2
2-shot Support Set。
Query 要辨别的目标。
解决上面问题的思路就是:训练一个神经网络其可以计算出 Query 与 Support Set 中
k
k
k 个类别的之间的相似度。与Query相似度越高的类别,是Query类别的概率越大。
值得注意的是:Support set 中的数据并不出现在训练集中。
Siamese网络介绍
Siamese网络可以通过训练,计算出数据之间的相似度。 Siamese译为‘孪生’,‘连体’。
通过一个函数(或神经网络)将输入映射到目标空间,在目标空间使用简单的距离(欧式距离等)进行相似度对比。在训练阶段最小化来自相同类别的一对样本的损失函数值,最大化来自不同类别的一对样本的损失函数值。
如下图:
上图中的红色框中的网络结构相同。
训练神经网络
处理训练集
训练上面的siamese神经网络首先要有数据集。我们处理训练数据如下图所示:
同一类别的选两张组合在一起,称为正样本,标签为1。
不同一类别的选两张组合在一起,称为负样本,标签为0。
训练网络
首先,用siamese网络分别提取
X
1
X1
X1 和
X
2
X2
X2 的特征
h
1
h1
h1 和
h
2
h2
h2,然后计算
z
=
∣
h
1
−
h
2
∣
z=|h1-h2|
z=∣h1−h2∣,
z
z
z 表示两个特征之间的区别。然后使用Dense layers 处理
z
z
z, 得到一个标量,然后用sigmoid激活函数做分类。
简单代码实现:
结果可以看下图所示:
上图中 Query 与Support Set 中的兔子最相似,那么Query的类别就是兔子。
Triplet Loss
上面介绍的一种siamese网络是双连体,那么我们也可以设计成三连体,其训练方式称为Triplet Loss。如下图:
Triplet Loss 训练数据集处理
训练网络![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/4e49eb4a4d1d120686a8cd1ddc12aaa2.png)
根据上面我们定义损失函数:
损失函数应鼓励最小化
d
+
d^+
d+,即最小化同类别之间的距离;鼓励最大化
d
−
d^-
d−,即最大化不同类别之间的距离。
我们指定一个margin,如果
d
−
>
=
d
+
+
m
a
r
g
i
n
d^->=d^++margin
d−>=d++margin ,我们认为这组数据分类正确,
L
o
s
s
=
0
Loss=0
Loss=0,如果条件不满足,说明分不开正负样本,则
L
o
s
s
=
d
−
−
d
+
+
m
a
r
g
i
n
Loss = d^--d^++margin
Loss=d−−d++margin。综上
L
o
s
s
=
m
a
x
{
0
,
d
−
−
d
+
+
m
a
r
g
i
n
}
Loss =max\{ 0,d^--d^++margin\}
Loss=max{0,d−−d++margin}
简单代码实现:
结果比较如图:
上图中 Query 与Support Set 中的兔子距离最小,那么Query的类别就是兔子。