Definition
该损失函数最早由 Hadsell, Chopra, LeCun 在论文 "Dimensionality Reduction by Learning an Invariant Mapping" (CVPR 2006) 中提出,主要用来做降维处理。
caffe中siamese网络用到Contrastive loss,这种损失函数可以有效地处理siamese网络中pair-data的数据关系。
Formula
$$L_c = \frac{1}{2N}\sum_{n=1}^{N}\left[\, y\, d_n^2 + (1-y)\,\max(margin - d_n,\, 0)^2 \,\right], \qquad d_n = \lVert a_n - b_n \rVert_2$$

(代码中的 legacy 版本对不相似对使用 $\max(margin - d_n^2,\, 0)$,即对距离的平方做 hinge。)
当 $y=1$ 时,表示两张图片相似,表达式只剩下前面部分 $L_c = \frac{1}{2N}\sum_{n=1}^{N} d_n^2$;当 $y=0$ 时,表示两张图片不相似,即表达式只剩后面部分 $L_c = \frac{1}{2N}\sum_{n=1}^{N} \max(margin - d_n,\, 0)^2$。
Code
输入:bottom[0],bottom[1],bottom[2]
特征a: (N*C*1*1)
特征b: (N*C*1*1)
相似性y: (N*1*1*1)
输出:top[0] -> (1*1*1*1)
(1)LayerSetUp
// Sets up the contrastive loss layer: validates input shapes and
// allocates the internal buffers used by Forward/Backward.
//   bottom[0]: features a, N x C x 1 x 1
//   bottom[1]: features b, N x C x 1 x 1
//   bottom[2]: similarity labels y, N x 1 x 1 x 1
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  // The two feature blobs must agree on channels; all inputs are flat.
  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
  CHECK_EQ(bottom[0]->height(), 1);
  CHECK_EQ(bottom[0]->width(), 1);
  CHECK_EQ(bottom[1]->height(), 1);
  CHECK_EQ(bottom[1]->width(), 1);
  CHECK_EQ(bottom[2]->channels(), 1);
  CHECK_EQ(bottom[2]->height(), 1);
  CHECK_EQ(bottom[2]->width(), 1);
  const int num = bottom[0]->num();
  const int dim = bottom[0]->channels();
  diff_.Reshape(num, dim, 1, 1);     // per-pair difference a_n - b_n
  diff_sq_.Reshape(num, dim, 1, 1);  // element-wise squared difference
  dist_sq_.Reshape(num, 1, 1, 1);    // squared distance d_n^2 per pair
  // Vector of ones: a dot product with it sums values across channels.
  Dtype* ones = summer_vec_.mutable_cpu_data();
  summer_vec_.Reshape(dim, 1, 1, 1);
  ones = summer_vec_.mutable_cpu_data();
  for (int c = 0; c < dim; ++c) {
    ones[c] = Dtype(1);
  }
}
(2) Forward前向传播
// Forward pass (CPU) of the contrastive loss:
//   L = 1/(2N) * sum_n [ y_n * d_n^2 + (1 - y_n) * max(margin - d_n, 0)^2 ]
// with d_n = ||a_n - b_n||_2. The legacy variant instead uses
// max(margin - d_n^2, 0) for dissimilar pairs. Writes the scalar loss
// to top[0] and caches diff_ / dist_sq_ for the backward pass.
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // diff_ = a - b, element-wise over the whole batch.
  caffe_sub(
      bottom[0]->count(),
      bottom[0]->cpu_data(),   // features a
      bottom[1]->cpu_data(),   // features b
      diff_.mutable_cpu_data());
  const int channels = bottom[0]->channels();
  const int num = bottom[0]->num();
  const Dtype margin = this->layer_param_.contrastive_loss_param().margin();
  const bool legacy_version =
      this->layer_param_.contrastive_loss_param().legacy_version();
  Dtype loss(0.0);
  for (int n = 0; n < num; ++n) {
    const Dtype* pair_diff = diff_.cpu_data() + n * channels;
    // Squared Euclidean distance: d^2 = <a_n - b_n, a_n - b_n>.
    dist_sq_.mutable_cpu_data()[n] =
        caffe_cpu_dot(channels, pair_diff, pair_diff);
    const Dtype d_sq = dist_sq_.cpu_data()[n];
    if (static_cast<int>(bottom[2]->cpu_data()[n])) {
      // Similar pair (y = 1): penalize the squared distance itself.
      loss += d_sq;
    } else {
      // Dissimilar pair (y = 0): hinge on the margin.
      if (legacy_version) {
        loss += std::max(margin - d_sq, Dtype(0.0));
      } else {
        const Dtype hinge =
            std::max<Dtype>(margin - sqrt(d_sq), Dtype(0.0));
        loss += hinge * hinge;
      }
    }
  }
  // Average over the batch and apply the conventional 1/2 factor.
  top[0]->mutable_cpu_data()[0] =
      loss / static_cast<Dtype>(num) / Dtype(2);
}
(2) Backward反向传播
// Backward pass (CPU) of the contrastive loss.
// Gradients are propagated to bottom[0] (features a) and bottom[1]
// (features b); the similarity labels (bottom[2]) get no gradient.
// For similar pairs (y = 1):   dL/da_n = +(a_n - b_n) * top_diff / N
//                              dL/db_n = -(a_n - b_n) * top_diff / N
// For dissimilar pairs (y = 0, non-legacy): scaled by
// -(margin - d)/d when d < margin, zero otherwise (d = ||a_n - b_n||_2).
template <typename Dtype>
void ContrastiveLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
Dtype margin = this->layer_param_.contrastive_loss_param().margin();
bool legacy_version =
this->layer_param_.contrastive_loss_param().legacy_version();
// i == 0: gradient w.r.t. a (sign +1); i == 1: w.r.t. b (sign -1).
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
// Scale by the loss weight (top[0]'s diff) and average over the batch.
const Dtype alpha = sign * top[0]->cpu_diff()[0] /
static_cast<Dtype>(bottom[i]->num());
int num = bottom[i]->num();
int channels = bottom[i]->channels();
for (int j = 0; j < num; ++j) {
Dtype* bout = bottom[i]->mutable_cpu_diff();
if (static_cast<int>(bottom[2]->cpu_data()[j])) { // similar pairs
// d(d^2)/da = 2(a - b); the loss's 1/2 factor cancels the 2,
// so the gradient is just alpha * diff_ (beta 0 overwrites bout).
caffe_cpu_axpby(
channels,
alpha,
diff_.cpu_data() + (j*channels),
Dtype(0.0),
bout + (j*channels));
} else { // dissimilar pairs
Dtype mdist(0.0);
Dtype beta(0.0);
if (legacy_version) {
// Legacy hinge max(margin - d^2, 0): gradient factor is -alpha.
mdist = margin - dist_sq_.cpu_data()[j];
beta = -alpha;
} else {
// Squared hinge max(margin - d, 0)^2: factor -(margin - d)/d;
// the 1e-4 guards against division by zero when d ~ 0.
Dtype dist = sqrt(dist_sq_.cpu_data()[j]);
mdist = margin - dist;
beta = -alpha * mdist / (dist + Dtype(1e-4));
}
if (mdist > Dtype(0.0)) {
caffe_cpu_axpby(
channels,
beta,
diff_.cpu_data() + (j*channels),
Dtype(0.0),
bout + (j*channels));
} else {
// Pair already farther apart than the margin: zero gradient.
caffe_set(channels, Dtype(0), bout + (j*channels));
}
}
}
}
}
}
Usage
# Contrastive loss layer usage.
# NOTE: protobuf text format only supports '#' comments — the original
# snippet's '//' comments would make the prototxt fail to parse.
layer {
  name: "contrastive_loss"
  type: "ContrastiveLoss"
  bottom: "feat_1"         # features a
  bottom: "feat_2"         # features b
  bottom: "sim"            # similarity label y (1 = similar, 0 = dissimilar)
  top: "contrastive_loss"  # scalar loss output
  contrastive_loss_param {
    margin: 1              # margin applied to dissimilar pairs
  }
}
如果输入数据是data a, data b,以及相似度y,在全连接层的输出后面加一个slice 层,将特征切分为feat_1,feat_2即可。
如果输入数据是 data a、data b 以及各自的 label,则需要添加一个层,根据两个 label 是否相同输出相似度 y。下面我们以 Identity2VerifyLayer 为例进行分析(其本质类似 slice 层,只是额外生成了相似度标签):
// Intentionally empty: this layer has no parameters to set up.
// All shape handling is done in Reshape().
template <typename Dtype>
void Identity2VerifyLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
}
// Reshapes the outputs for pair splitting.
//   bottom[0]: interleaved features (a0, b0, a1, b1, ...), batch 2M
//   bottom[1]: matching labels, batch 2M
//   top[0]/top[1]: per-branch features, batch M
//   top[2]: similarity labels, batch M
template <typename Dtype>
void Identity2VerifyLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
  // The batch interleaves pairs, so it must be even; otherwise the
  // integer halving below would silently drop the last sample.
  CHECK_EQ(bottom[0]->shape(0) % 2, 0);
  // Features and labels must agree on the batch size.
  CHECK_EQ(bottom[0]->shape(0), bottom[1]->shape(0));
  vector<int> top_shape = bottom[0]->shape();
  top_shape[0] /= 2;
  top[0]->Reshape(top_shape);
  top[1]->Reshape(top_shape);
  vector<int> label_shape = bottom[1]->shape();
  label_shape[0] /= 2;
  top[2]->Reshape(label_shape);
}
// Forward pass: de-interleaves paired samples and derives the
// similarity label for each pair.
//   Even rows of bottom[0] -> top[0], odd rows -> top[1].
//   top[2][m] = 1 if the pair's labels (rows 2m and 2m+1 of bottom[1])
//   differ by less than 1, else 0.
template <typename Dtype>
void Identity2VerifyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
  const int feature_size = bottom[0]->count(1);
  for (int n = 0; n < bottom[0]->num(); ++n) {
    caffe_copy(
        feature_size,
        bottom[0]->cpu_data() + n * feature_size,
        top[n & 1]->mutable_cpu_data() + (n / 2) * feature_size
    );
  }
  const int label_size = bottom[1]->count(1);
  for (int n = 0; n < bottom[1]->num(); n += 2) {
    // Compare |l1 - l2| < 1 without calling abs(): the original used
    // unqualified abs, which resolves to the C int overload and
    // silently truncates floating-point labels.
    const Dtype delta = *(bottom[1]->cpu_data() + n * label_size) -
        *(bottom[1]->cpu_data() + (n + 1) * label_size);
    const Dtype label =
        (delta < Dtype(1) && delta > Dtype(-1)) ? Dtype(1.0)   // sim = 1
                                                : Dtype(0.0);  // sim = 0
    // BUG FIX: the original caffe_copy(label_size, &label, ...) read
    // label_size elements starting at a single scalar's address —
    // out-of-bounds whenever label_size > 1. Fill the slot instead.
    caffe_set(label_size, label,
        top[2]->mutable_cpu_data() + (n / 2) * label_size);
  }
}
// Backward pass: routes gradients back through the de-interleaving done
// in Forward_cpu — top[0]'s diffs go to even rows of bottom[0], top[1]'s
// to odd rows. The labels (bottom[1]) receive no gradient.
template <typename Dtype>
void Identity2VerifyLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  const int feature_size = bottom[0]->count(1);
  for (int n = 0; n < bottom[0]->num(); ++n) {
    caffe_copy(
        feature_size,
        top[n & 1]->cpu_diff() + (n / 2) * feature_size,
        bottom[0]->mutable_cpu_diff() + n * feature_size
    );
  }
}  // BUG FIX: the original snippet was missing this closing brace.
# Identity2Verify layer usage: splits an interleaved batch into the two
# feature branches plus a similarity label.
# NOTE: protobuf text format only supports '#' comments — the original
# snippet's '//' comments would make the prototxt fail to parse.
layer {
  name: "slice"
  type: "Identity2Verify"
  bottom: "ip2"
  bottom: "label"
  top: "feat_1"  # top[0]
  top: "feat_2"  # top[1]
  top: "sim"     # top[2]
  # NOTE(review): the C++ implementation shown above never reads
  # slice_param; this field appears to be ignored — confirm before relying
  # on it.
  slice_param {
    axis: 0
  }
}