AlignedReID 源码解析

写东西最好中午写,因为早晚不想写

一、论文回顾

论文获取:http://AlignedReID: Surpassing Human-Level Performance in Person Re-Identification

AlignedReID很有意思,提出了动态规划下计算局部最短路径的方法。这条最短路径中的一条边就对应了一对局部特征的匹配,它给出了一种人体对齐的方式,在保证身体个部分相对顺序正确的情况下,这种对齐方式的总距离是最短的。在训练的时候,最短路径的长度被加入到局部损失函数,辅助学习行人的整体特征。

要看懂这张图,我们分几个小问题来一步步分析。

1. 7x7的矩阵是怎么得到的?——用F和G分别代表两张图,其特征分别为\mathbf{F}=\left\{\boldsymbol{f}_{1}, \boldsymbol{f}_{2}, \ldots, \boldsymbol{f}_{H}\right\}\mathrm{G}=\left\{\boldsymbol{g}_{1}, \boldsymbol{g}_{2}, \ldots, \boldsymbol{g}_{H}\right\},在文中H=7,分别计算各部分的相似度(根据公式一),得出7x7=49种相似度,即一个7x7的矩阵。

2. 拐点的含义是什么?——图中拐点表示两张图像切片相对应的位置,比如(2,4)表示Image A的第二块切片和Image B的第四块切片是Aligned

3. 如何理解最短路径?——分黑线和黑箭头两部分理解。首先,第一行黑线横跨矩阵中1-4行,表示Image A的第一块切片与Image B的1-4块切片是corresponding alignment;其次,第一行黑箭头同样横跨矩阵中1-4行,表示Image A的第一块切片与整个Image B的最大相似度(或最短距离),换句话说就是计算机认为A的第一块切片只于B的前四块切片相似,与后面三块完全不相似,所以剩下的三块不属于对应对齐区域,也就无关最短路径的计算。

4. 黑线的最短路径计算公式是什么?(你会说是公式2,没有错,那你能自己推导吗)

Image A的第一块切片与Image B的距离:d_{1,1}+d_{1,2}+d_{1,3}+d_{1,4}

Image A的第二块切片与Image B的距离:d_{2,4}+d_{2,5}

Image A的第三块切片与Image B的距离:d_{3,5}

Image A的第四块切片与Image B的距离:d_{4,5}+d_{4,6}

Image A的第五块切片与Image B的距离:d_{5,6}

Image A的第六块切片与Image B的距离:d_{6,6}+d_{6,7}

Image A的第七块切片与Image B的距离:d_{7,7}

二、代码解析 

AlignedReID的代码与Triplet loss很相似,由于之前已经详细解析过Triplet loss源码了https://blog.csdn.net/m0_57541899/article/details/122243847?spm=1001.2014.3001.5501icon-default.png?t=LBL2https://blog.csdn.net/m0_57541899/article/details/122243847?spm=1001.2014.3001.5501这里直接在代码上解析。先把几个值得注意的函数写在前面:

1. torch.mean(object,dim,keepdim):对指定维度求平均,将指定的那维全变成1。如一个大小为(2,3)的tensor,其中2代表0维,3代表1维,对0维求平均,则tensor大小变为(1,3)

2.  permute(dims):将tensor的维度换位。如Imag的size是(28,28,3),就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的张量

3. squzze(a,axis=None):a为输入数组,axis用于删除指定维度

4. clamp(float min-number,float max-number,float parameter):用法为fmin<fp<fmax,则返回fp;fp>fmax,则返回fmax;fp<fmin,则返回fmin

from __future__ import print_function
import torch


def normalize(x, axis=-1):
  """Normalizing to unit length along the specified dimension.
  Args:
    x: pytorch Variable
  Returns:
    x: pytorch Variable, same shape as input      
  """
  x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)  # torch.norm(input,p,dim):calculate the norm in the specified dimension(p-dim)
  return x


def euclidean_dist(x, y):
  """
  Args:
    x: pytorch Variable, with shape [m, d]
    y: pytorch Variable, with shape [n, d]
  Returns:
    dist: pytorch Variable, with shape [m, n]
  """
  m, n = x.size(0), y.size(0)  # n:128  m:128
  xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)  # 1:means axis=1
  yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
  dist = xx + yy
  dist.addmm_(1, -2, x, y.t())  # dist=1*dist-2*(x@y^T);  # Note:dist.addmm_ & dist.addmm
  dist = dist.clamp(min=1e-12).sqrt()  # gain the distance matrix between samples
  return dist


def batch_euclidean_dist(x, y):  # x,y={Tensor:(128,8,128)}={Tensor:(N,m,d)}
  """
  Args:
    x(local_feat): pytorch Variable, with shape [N, m, d]=[128,8,128]
    y(local_feat[p_inds]): pytorch Variable, with shape [N, n, d]=[128,8,128]
  Returns:
    dist: pytorch Variable, with shape [N, m, n]
  """
  assert len(x.size()) == 3
  assert len(y.size()) == 3
  assert x.size(0) == y.size(0)
  assert x.size(-1) == y.size(-1)

  N, m, d = x.size()  # N:128 m:8 d=128
  N, n, d = y.size()  # n=8

  # shape [N, m, n]
  xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)  # xx={Tensor:(128,8,8)}
  yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)  # yy={Tensor:(128,8,8)}
  dist = xx + yy
  dist.baddbmm_(1, -2, x, y.permute(0, 2, 1))  # dist=1*dist-2*(x@y.permute);  # Note:dist.baddmm_ & dist.baddmm
  dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
  return dist


def shortest_dist(dist_mat):
  """Parallel version.
  Args:
    dist_mat: pytorch Variable, available shape:
      1) [m, n]
      2) [m, n, N], N is batch size
      3) [m, n, *], * can be arbitrary additional dimensions
  Returns:
    dist: three cases corresponding to `dist_mat`:
      1) scalar
      2) pytorch Variable, with shape [N]
      3) pytorch Variable, with shape [*]
  """
  m, n = dist_mat.size()[:2]
  # Just offering some reference for accessing intermediate distance.
  dist = [[0 for _ in range(n)] for _ in range(m)]  # initialization
  for i in range(m):
    for j in range(n):
      if (i == 0) and (j == 0):
        dist[i][j] = dist_mat[i, j]
      elif (i == 0) and (j > 0):
        dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
      elif (i > 0) and (j == 0):
        dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
      else:
        dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
  dist = dist[-1][-1]
  return dist


def local_dist(x, y):
  """
  Args:
    x: pytorch Variable, with shape [M, m, d]
    y: pytorch Variable, with shape [N, n, d]
  Returns:
    dist: pytorch Variable, with shape [M, N]
  """
  M, m, d = x.size()
  N, n, d = y.size()
  x = x.contiguous().view(M * m, d)
  y = y.contiguous().view(N * n, d)
  # shape [M * m, N * n]
  dist_mat = euclidean_dist(x, y)
  dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
  # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N]
  dist_mat = dist_mat.contiguous().view(M, m, N, n).permute(1, 3, 0, 2)
  # shape [M, N]
  dist_mat = shortest_dist(dist_mat)
  return dist_mat


def batch_local_dist(x, y):
  """
  Args:
    x(local_feat): pytorch Variable, with shape [N, m, d]=[128,8,128]
    y(local_feat[p_inds]): pytorch Variable, with shape [N, n, d]=[128,8,128]
  Returns:
    dist: pytorch Variable, with shape [N]
  """
  assert len(x.size()) == 3  # judge the 'x' matrix whether is 3-dim,if not,report error
  assert len(y.size()) == 3
  assert x.size(0) == y.size(0)
  assert x.size(-1) == y.size(-1)

  # shape [N, m, n]
  dist_mat = batch_euclidean_dist(x, y)  # dist_mat={Tensor:(128,8,8)}
  dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)  # normalization
  # shape [N]
  dist = shortest_dist(dist_mat.permute(1, 2, 0))  # (128,8,8)-->(8,8,128),then calculate the shortest distance under dynamic planning
  return dist


def hard_example_mining(dist_mat, labels, return_inds=False):
  """For each anchor, find the hardest positive and negative sample.
  Args:
    dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
    labels: pytorch LongTensor, with shape [N]
    return_inds: whether to return the indices. Save time if `False`(?)
  Returns:
    dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
    dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    p_inds: pytorch LongTensor, with shape [N]; 
      indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
    n_inds: pytorch LongTensor, with shape [N];
      indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
  NOTE: Only consider the case in which all labels have same num of samples, 
    thus we can cope with all anchors in parallel.
  """

  assert len(dist_mat.size()) == 2  # judge the 'dist_mat' matrix whether is 2-dim,if not,report error
  assert dist_mat.size(0) == dist_mat.size(1)  # judge the 'dist_mat' matrix whether is Square matrix,if not,report error
  N = dist_mat.size(0)  # gain the 'dist_mat' matrix length i.e.N(N: 128)

  # shape [N, N]
  is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())  # gain the positive sample
  is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())  # gain the negative sample

  # `dist_ap` means distance(anchor, positive)
  # both `dist_ap` and `relative_p_inds` with shape [N, 1]
  dist_ap, relative_p_inds = torch.max(
    dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)  # calculate the min similarity(i.e.max distance) between anchor and positive,and return the corresponding serial number
  # `dist_an` means distance(anchor, negative)
  # both `dist_an` and `relative_n_inds` with shape [N, 1]
  dist_an, relative_n_inds = torch.min(
    dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)  # calculate the max similarity(i.e.min distance) between anchor and negative,and return the corresponding serial number
  # shape [N]
  dist_ap = dist_ap.squeeze(1)  # compression dimension
  dist_an = dist_an.squeeze(1)

  if return_inds:
    # shape [N, N]
    ind = (labels.new().resize_as_(labels)
           .copy_(torch.arange(0, N).long())
           .unsqueeze( 0).expand(N, N))
    # shape [N, 1]
    p_inds = torch.gather(
      ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
    n_inds = torch.gather(
      ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
    # shape [N]
    p_inds = p_inds.squeeze(1)
    n_inds = n_inds.squeeze(1)
    return dist_ap, dist_an, p_inds, n_inds

  return dist_ap, dist_an


def global_loss(tri_loss, global_feat, labels, normalize_feature=True):
  """
  Args:
    tri_loss: a `TripletLoss` object
    global_feat: pytorch Variable, shape [N, C]
    labels: pytorch LongTensor, with shape [N]
    normalize_feature: whether to normalize feature to unit length along the 
      Channel dimension
  Returns:
    loss: pytorch Variable, with shape [1]
    p_inds: pytorch LongTensor, with shape [N]; 
      indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
    n_inds: pytorch LongTensor, with shape [N];
      indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    =============
    For Debugging
    =============
    dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
    dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    ===================
    For Mutual Learning
    ===================
    dist_mat: pytorch Variable, pairwise euclidean distance; shape [N, N]
  """
  if normalize_feature:
    global_feat = normalize(global_feat, axis=-1)
  # shape [N, N]
  dist_mat = euclidean_dist(global_feat, global_feat)
  dist_ap, dist_an, p_inds, n_inds = hard_example_mining(
    dist_mat, labels, return_inds=True)
  loss = tri_loss(dist_ap, dist_an)
  return loss, p_inds, n_inds, dist_ap, dist_an, dist_mat


def local_loss(
    tri_loss,
    local_feat,
    p_inds=None,
    n_inds=None,
    labels=None,
    normalize_feature=True):
  """
  Args:
    tri_loss: a `TripletLoss` object
    local_feat: pytorch Variable, shape [N, H, c] (NOTE THE SHAPE!)
    p_inds: pytorch LongTensor, with shape [N]; 
      indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
    n_inds: pytorch LongTensor, with shape [N];
      indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    labels: pytorch LongTensor, with shape [N]
    normalize_feature: whether to normalize feature to unit length along the 
      Channel dimension
  
  If hard samples are specified by `p_inds` and `n_inds`, then `labels` is not 
  used. Otherwise, local distance finds its own hard samples independent of 
  global distance.
  
  Returns:
    loss: pytorch Variable,with shape [1]
    =============
    For Debugging
    =============
    dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
    dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    ===================
    For Mutual Learning
    ===================
    dist_mat: pytorch Variable, pairwise local distance; shape [N, N]
  """
  if normalize_feature:
    local_feat = normalize(local_feat, axis=-1)  # local_feat={Tensor:(128,8,128)}
  if p_inds is None or n_inds is None:
    dist_mat = local_dist(local_feat, local_feat)
    dist_ap, dist_an = hard_example_mining(dist_mat, labels, return_inds=False)
    loss = tri_loss(dist_ap, dist_an)
    return loss, dist_ap, dist_an, dist_mat
  else:
    dist_ap = batch_local_dist(local_feat, local_feat[p_inds])  # dist_ap:local_dist_ap;  local_feat[p_inds]:positive_local_feat
    dist_an = batch_local_dist(local_feat, local_feat[n_inds])
    loss = tri_loss(dist_ap, dist_an)  # loss:local_loss
    return loss, dist_ap, dist_an

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值