The MoCo paper contains the following pseudocode:
```python
# positive logits: Nx1
l_pos = bmm(q.view(N, 1, C), k.view(N, C, 1))
# negative logits: NxK
l_neg = mm(q.view(N, C), queue.view(C, K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
# contrastive loss, Eqn.(1)
labels = zeros(N)  # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)
```
The first three steps are easy to follow, but what do the last two lines mean?
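For reference, here is a minimal runnable PyTorch version of that pseudocode. The sizes N, C, K and the temperature t are placeholder values chosen for illustration, and the features are just random normalized vectors:

```python
import torch
import torch.nn.functional as F

N, C, K, t = 8, 128, 4096, 0.07                # illustrative sizes; t is the temperature
q = F.normalize(torch.randn(N, C), dim=1)      # query features
k = F.normalize(torch.randn(N, C), dim=1)      # positive key features
queue = F.normalize(torch.randn(C, K), dim=0)  # queue of negative keys, one per column

# positive logits: Nx1
l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1)
# negative logits: NxK
l_neg = torch.mm(q, queue)
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# positives are the 0-th class; cross_entropy expects integer class indices
labels = torch.zeros(N, dtype=torch.long)
loss = F.cross_entropy(logits / t, labels)
print(loss)
```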
First, look at how the PyTorch documentation describes CrossEntropyLoss:
The loss can be described as:
$$\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right)$$
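To make the definition concrete, a quick numeric check (the logit values here are arbitrary) confirms that the right-hand form matches F.cross_entropy:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([2.0, 0.5, -1.0])    # arbitrary logits for one sample
cls = 0                               # the target class

# -x[class] + log(sum_j exp(x[j])), straight from the documentation formula
manual = -x[cls] + torch.logsumexp(x, dim=0)
builtin = F.cross_entropy(x.view(1, -1), torch.tensor([cls]))
assert torch.allclose(manual, builtin)  # both compute the same value
```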
Now substitute logits/t and labels into this definition. For the i-th row of logits/t and the i-th entry of labels:
$$\text{loss}(logits^{(i)}/t,\; labels^{(i)}) = -\log\left(\frac{\exp(logits^{(i)}[labels^{(i)}]/t)}{\sum_{j=0}^{K} \exp(logits^{(i)}[j]/t)}\right)$$

(The sum runs over the 1+K columns of logits, i.e. j = 0, ..., K.)
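This per-row view can be checked directly with reduction='none', which returns one loss value per row instead of the batch mean. The shapes follow the MoCo pseudocode; the values are random and the targets here are arbitrary, to match the general formula:

```python
import torch
import torch.nn.functional as F

N, K, t = 4, 10, 0.07                       # illustrative sizes
logits = torch.randn(N, 1 + K)              # one row per query, 1+K classes
labels = torch.randint(0, 1 + K, (N,))      # arbitrary targets, for the general formula

# one loss value per row, i.e. loss(logits^(i)/t, labels^(i))
per_row = F.cross_entropy(logits / t, labels, reduction='none')

# the same quantity written out from the formula above
picked = (logits / t)[torch.arange(N), labels]          # logits^(i)[labels^(i)]/t
manual = -picked + torch.logsumexp(logits / t, dim=1)   # -x[class] + log sum_j exp(x[j])
assert torch.allclose(per_row, manual)
```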
Since labels = zeros(N), every $labels^{(i)}$ is 0:
$$\text{loss}(logits^{(i)}/t,\; 0) = -\log\left(\frac{\exp(logits^{(i)}[0]/t)}{\sum_{j=0}^{K} \exp(logits^{(i)}[j]/t)}\right)$$
And by construction of logits in the pseudocode,

$$logits^{(i)}[0] = q \cdot k_+, \qquad logits^{(i)}[j] = q \cdot k_j$$

so
$$\text{loss}(logits^{(i)}/t,\; 0) = -\log\left(\frac{\exp(q \cdot k_+/t)}{\sum_{j=0}^{K} \exp(q \cdot k_j/t)}\right)$$

which is exactly the contrastive (InfoNCE) loss of Eqn.(1) in the MoCo paper, with t playing the role of the temperature τ.
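Putting it all together, a short sketch (random feature values, illustrative sizes) confirms that CrossEntropyLoss applied to logits/t with zero labels computes exactly this InfoNCE expression:

```python
import torch
import torch.nn.functional as F

N, C, K, t = 8, 128, 64, 0.07                  # illustrative sizes
q = F.normalize(torch.randn(N, C), dim=1)
k = F.normalize(torch.randn(N, C), dim=1)      # positive keys
queue = F.normalize(torch.randn(C, K), dim=0)  # negative keys

l_pos = (q * k).sum(dim=1, keepdim=True)       # q·k_+ for each row, Nx1
l_neg = q @ queue                              # q·k_j, NxK
logits = torch.cat([l_pos, l_neg], dim=1)      # Nx(1+K)

# InfoNCE written out by hand: -log(exp(q·k_+/t) / sum_j exp(q·k_j/t))
infonce = -(l_pos.squeeze(1) / t) + torch.logsumexp(logits / t, dim=1)

# ...and via cross-entropy with zero labels, as in the pseudocode
labels = torch.zeros(N, dtype=torch.long)
ce = F.cross_entropy(logits / t, labels, reduction='none')
assert torch.allclose(infonce, ce)
print(infonce.mean())                          # the batch loss
```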