import torch
def batch_euclidean_dist(x, y):
    """
    Args:
      x: pytorch tensor, with shape [Batch size, Local part, Feature channel]
      y: pytorch tensor, with shape [Batch size, Local part, Feature channel]
    Returns:
      dist: pytorch tensor, with shape [Batch size, Local part, Local part]
    """
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)
    N, m, d = x.size()
    N, n, d = y.size()
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a . b, computed per batch element.
    xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)
    yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)
    dist = xx + yy
    # In-place batched matmul: dist = 1 * dist + (-2) * x @ y^T.
    dist.baddbmm_(x, y.permute(0, 2, 1), beta=1, alpha=-2)
    # Clamp before sqrt for numerical stability.
    dist = dist.clamp(min=1e-12).sqrt()
    return dist
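# Usage sketch (shapes only; the sizes here are illustrative, not required):
#   x = torch.randn(32, 8, 128)          # [N, m, d]
#   y = torch.randn(32, 8, 128)          # [N, n, d]
#   batch_euclidean_dist(x, y).shape     # -> torch.Size([32, 8, 8])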
def shortest_dist(dist_mat):
    """Parallel version.
    Args:
      dist_mat: pytorch tensor, available shape:
        1) [m, n]
        2) [m, n, N], N is batch size
        3) [m, n, *], * can be arbitrary additional dimensions
    Returns:
      dist: three cases corresponding to `dist_mat`:
        1) scalar
        2) pytorch tensor, with shape [N]
        3) pytorch tensor, with shape [*]
    """
    m, n = dist_mat.size()[:2]
    # Dynamic programming: dist[i][j] holds the cost of the cheapest monotone
    # path from (0, 0) to (i, j), moving only down or right. Any trailing
    # dimensions of dist_mat are carried through element-wise, which is what
    # makes this the "parallel version".
    dist = [[0 for _ in range(n)] for _ in range(m)]
    for i in range(m):
        for j in range(n):
            if (i == 0) and (j == 0):
                dist[i][j] = dist_mat[i, j]
            elif (i == 0) and (j > 0):
                dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
            elif (i > 0) and (j == 0):
                dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
            else:
                dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
    dist = dist[-1][-1]
    return dist
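# Worked example (a minimal sketch with a hand-made 2x2 cost matrix):
#   costs = torch.tensor([[1., 9.],
#                         [1., 1.]])
#   shortest_dist(costs)  # -> tensor(3.), via the path (0,0) -> (1,0) -> (1,1)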
def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pytorch tensor, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices; setting it to `False` when
        they are not needed saves time
    Returns:
      dist_ap: pytorch tensor, distance(anchor, positive); shape [N]
      dist_an: pytorch tensor, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have the same number of
      samples, so that all anchors can be processed in parallel.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)
    # shape [N, N]: is_pos[i, j] is True iff samples i and j share a label.
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
    # Hardest positive: the *farthest* same-label sample. Masking flattens
    # dist_mat; the equal-count assumption lets us reshape it back to
    # [N, num_pos_per_anchor].
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # Hardest negative: the *closest* different-label sample.
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)
    if return_inds:
        # Map the relative indices (positions within each anchor's positive or
        # negative subset) back to absolute sample indices in the batch.
        ind = (torch.arange(N, dtype=torch.long, device=labels.device)
               .unsqueeze(0).expand(N, N))
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds)
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds
    return dist_ap, dist_an
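# Example (a minimal sketch; the 2-identity x 2-image batch is made up here):
#   feat = torch.randn(4, 128)
#   labels = torch.tensor([0, 0, 1, 1])
#   d_ap, d_an = hard_example_mining(euclidean_dist(feat, feat), labels)
#   # d_ap[i]: max distance from i to a same-label sample (self included)
#   # d_an[i]: min distance from i to a different-label sample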
def euclidean_dist(x, y):
    """
    Args:
      x: pytorch tensor, with shape [m, d]
      y: pytorch tensor, with shape [n, d]
    Returns:
      dist: pytorch tensor, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # In-place: dist = 1 * dist + (-2) * x @ y^T. The legacy positional form
    # addmm_(1, -2, x, y.t()) is deprecated and removed in current PyTorch.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()
    return dist
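# Usage sketch (illustrative sizes): for global features this yields the full
# pairwise matrix consumed by hard_example_mining above:
#   feat = torch.randn(32, 2048)         # [N, d]
#   euclidean_dist(feat, feat).shape     # -> torch.Size([32, 32])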
def batch_local_dist(x, y):
    """
    Args:
      x: pytorch tensor, with shape [N, m, d]
      y: pytorch tensor, with shape [N, n, d]
    Returns:
      dist: pytorch tensor, with shape [N]
    """
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)
    # shape [N, m, n]
    dist_mat = batch_euclidean_dist(x, y)
    # Normalize each element to (0, 1): (e^d - 1) / (e^d + 1) = tanh(d / 2).
    dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
    # shortest_dist expects [m, n, *], so move the batch dim last.
    dist = shortest_dist(dist_mat.permute(1, 2, 0))
    return dist
if __name__ == '__main__':
    x = torch.randn(32, 8, 2048)
    y = torch.randn(32, 8, 2048)
    local_dist = batch_local_dist(x, y)
    print(local_dist)
    # Random 2048-dim parts are far apart, so every normalized entry of the
    # local distance matrix saturates near 1 and the (8 + 8 - 1)-step shortest
    # path sums to ~15 for each of the 32 pairs:
    # tensor([15., 15., 15., ..., 15., 15., 15.])
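    # A further sketch: global hard mining on synthetic data. The 8-identity x
    # 4-image layout below is made up, chosen so every label has the same
    # sample count, as hard_example_mining assumes.
    feat = torch.randn(32, 128)
    labels = torch.arange(8).repeat_interleave(4)
    dist_ap, dist_an = hard_example_mining(euclidean_dist(feat, feat), labels)
    print(dist_ap.shape, dist_an.shape)  # torch.Size([32]) torch.Size([32])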