Generalizing A Person Retrieval Model Hetero- and Homogeneously

1、论文:Generalizing A Person Retrieval Model Hetero- and Homogeneously

2、代码:https://github.com/zhunzhong07/HHL
在这里插入图片描述

文章idea
文章的网络框架和CamStyle的框架一样,比较简单。

创新点在于使用了triplet,并实现两个连接作用。

第一个:将源域与目标域连接起来(domain connectedness),称之为:heterogeneous learning。因为源域与目标域没有交集,它们之间所有行人属于不同类别。

第二个:实现镜头不变学习(camera invariance),称之为:homogeneous learning。因为作者使用StarGAN为每个图像生成了不同镜头风格的行人,这些生成的行人与真实的行人具有相同的标签。而目标集中类别不知道,可以认为所有真实图像之间的标签是不同的。这样可以连接镜头的风格。

数据的获取方式:
在源域随机选取128幅图像,包括16个类别,每个类别8幅图像;
在目标域随机选取16幅图像,例如在market有6个镜头,使用stargan为每幅图像生成6种风格的图像,获得每幅图像与对应的6个生成的图像。这6个假图像与真图像具有相同标签,这16幅图像具有不同标签,小批量大小为:112。

获取目标域图像代码,每个类别获得1+C幅图像(C为镜头数量):

class CameraPreprocessor(object):
    def __init__(self, dataset, root=None, target_path=None, target_camstyle_path=None, transform=None, num_cam=6):
        super(CameraPreprocessor, self).__init__()
        self.dataset = dataset
        self.root = root
        self.target_path = target_path
        self.target_camstyle_path = target_camstyle_path
        self.transform = transform
        self.num_cam = num_cam

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(index) for index in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        img_all = []
        fname_all = []
        pid_all = []
        camid_all = []
        fname, _, camid = self.dataset[index]
        # randomly assign pseudo label to unlabeled target image
        pid = int(torch.rand(1) * 10000 + 1000)
        if self.root is not None:
            for i in range(self.num_cam+1):
                if i == 0:
                    fpath = osp.join(self.root, self.target_path, fname)
                else:
                    fpath = osp.join(self.root, self.target_camstyle_path,
                                     fname[:-4] + '_fake_' + str(i) + '.jpg')
                img = Image.open(fpath).convert('RGB')
                if self.transform is not None:
                    img = self.transform(img)
                img_all.append(img)
                fname_all.append(fname)
                pid_all.append(pid)
                camid_all.append(camid)

        return img_all, fname_all, pid_all, camid_all

结合源域的小批量128和目标域的小批量112作为一个批次,共同优化网络,里面总共包括16+16=32个行人类别。

最后的triplet loss的代码:

from __future__ import absolute_import
import torch
from torch import nn
from torch.autograd import Variable

class TripletLoss(nn.Module):
    def __init__(self, margin=0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().view(1))
            dist_an.append(dist[i][mask[i] == 0].min().view(1))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # Compute ranking hinge loss
        y = dist_an.data.new()
        y.resize_as_(dist_an.data)
        y.fill_(1)
        y = Variable(y)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        prec = (dist_an.data > dist_ap).data.float().mean()
        return loss, prec
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值