![在这里插入图片描述](https://img-blog.csdnimg.cn/20210605145422758.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NjgxNTMzMA==,size_16,color_FFFFFF,t_70)
from __future__ import absolute_import
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from aligned.HorizontalMaxPool2D import HorizontalMaxPool2d
__all__ = ['ResNet50', 'ResNet101']
class ResNet50(nn.Module):
    """ResNet-50 backbone with an optional AlignedReID-style local branch.

    The torchvision ResNet-50 is truncated before its last two children so
    the spatial feature map is kept; the global feature is obtained by
    average pooling that map, and a linear layer maps it to class logits.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: set of loss names deciding what ``forward`` returns in
            training mode when ``aligned`` is False. Supported values are
            ``{'softmax'}``, ``{'metric'}`` and ``{'softmax', 'metric'}``.
            Defaults to ``{'softmax'}``.
        aligned: if True, build the local branch (BatchNorm -> ReLU ->
            horizontal max pool -> 1x1 conv to 128 channels) and return its
            output from ``forward`` in training mode.
        **kwargs: accepted for interface compatibility; unused.
    """

    def __init__(self, num_classes, loss=None, aligned=False, **kwargs):
        super(ResNet50, self).__init__()
        # Fresh set per instance instead of a shared mutable default argument.
        self.loss = {'softmax'} if loss is None else loss
        resnet50 = torchvision.models.resnet50(pretrained=False)
        # Drop the last two children (torchvision's global avgpool and fc)
        # so self.base outputs a spatial feature map with 2048 channels.
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048
        self.aligned = aligned
        self.horizon_pool = HorizontalMaxPool2d()
        if self.aligned:
            # Layers of the local-feature branch; only created when needed.
            self.bn = nn.BatchNorm2d(2048)
            self.relu = nn.ReLU(inplace=True)
            self.conv1 = nn.Conv2d(2048, 128, kernel_size=1, stride=1,
                                   padding=0, bias=True)

    def forward(self, x):
        """Compute logits and/or features for a batch of images.

        In eval mode, returns only the flattened global feature ``f``.
        In training mode, returns ``(y, f, lf)`` when ``aligned`` is True
        (logits, global feature, local feature); otherwise returns ``y``,
        ``f`` or ``(y, f)`` according to ``self.loss``.

        Raises:
            KeyError: in training mode with ``aligned`` False, when
                ``self.loss`` is not one of the supported combinations.
        """
        x = self.base(x)
        if self.aligned and self.training:
            # Local branch works on the un-pooled backbone feature map.
            lf = self.bn(x)
            lf = self.relu(lf)
            lf = self.horizon_pool(lf)
            lf = self.conv1(lf)
        # Global average pooling over the full spatial extent, then flatten.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)
        if self.aligned:
            # lf is always defined here: this path is only reached in
            # training mode, where the local branch above ran.
            return y, f, lf
        if self.loss == {'softmax'}:
            return y
        if self.loss == {'metric'}:
            return f
        if self.loss == {'softmax', 'metric'}:
            return y, f
        # Fail loudly instead of printing and implicitly returning None,
        # which would only surface later as a confusing error in the caller.
        raise KeyError('loss settings error: {}'.format(self.loss))
if __name__ == '__main__':
    # Smoke test: build an aligned model and print the output shapes.
    model = ResNet50(num_classes=751, loss={'softmax', 'metric'}, aligned=True)
    # torch.Tensor(*sizes) returns uninitialized memory (possibly NaN/inf);
    # use well-defined random input instead.
    imgs = torch.randn(32, 3, 256, 128)
    # A freshly constructed nn.Module is in training mode, so the aligned
    # model returns (logits, global features, local features).
    y, f, local = model(imgs)
    print(y.size())
    print(f.size())
    print(local.size())
# Expected output of the demo above:
# torch.Size([32, 751])
# torch.Size([32, 2048])
# torch.Size([32, 128, 8, 1])