AlignedReID 网络改写

在这里插入图片描述

from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from aligned.HorizontalMaxPool2D import HorizontalMaxPool2d

__all__ = ['ResNet50', 'ResNet101']

class ResNet50(nn.Module):
    def __init__(self, num_classes, loss={'softmax'}, aligned=False, **kwargs):
        super(ResNet50, self).__init__()
        self.loss = loss
        resnet50 = torchvision.models.resnet50(pretrained = False)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048 # feature dimension
        self.aligned = aligned
        self.horizon_pool = HorizontalMaxPool2d()
        if self.aligned:
            self.bn = nn.BatchNorm2d(2048)
            self.relu = nn.ReLU(inplace = True)
            self.conv1 = nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):

        x = self.base(x)

        if self.aligned and self.training:
            lf = self.bn(x)
            lf = self.relu(lf)
            lf = self.horizon_pool(lf)
            lf = self.conv1(lf)

        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        #f = 1. * f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12)
        if not self.training:
            return f
        y = self.classifier(f)
        if self.aligned:
            return y, f, lf
        else:
            if self.loss == {'softmax'}:
                return y
            elif self.loss == {'metric'}:
                return f
            elif self.loss == {'softmax', 'metric'}:
                return y, f
            else:
                print('loss settings error')
if __name__ == '__main__':
    model = ResNet50(num_classes = 751, loss = {'softmax', 'metric'}, aligned = True)
    imgs = torch.Tensor(32, 3, 256, 128)
    y, f, local = model(imgs)
    print(y.size())
    print(f.size())
    print(local.size())
torch.Size([32, 751])
torch.Size([32, 2048])
torch.Size([32, 128, 8, 1])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值