# -*- encoding: utf-8 -*-
"""
@File : ResNet.py
@Time : 2021-05-08 14:50
@Author : XD
@Email : gudianpai@qq.com
@Software: PyCharm
"""
from __future__ import absolute_import
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed
class ResNet50(nn.Module):
    """ResNet-50 backbone with an optional linear classification head.

    The backbone is torchvision's ``resnet50`` built with
    ``pretrained=False`` (random initialisation). Its last two child
    modules are removed: the global average pool and the ``fc`` layer.
    Global average pooling is applied in :meth:`forward` instead.

    Args:
        num_class: number of classes predicted by the classifier head.
        loss: selects what :meth:`forward` returns in training mode. Pass an
            iterable of names drawn from ``'softmax'`` and ``'metric'``, or a
            single name as a plain string. ``None`` (the default) means
            ``{'softmax', 'metric'}``. The classifier head is created unless
            the selection is exactly ``{'metric'}``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_class, loss=None, **kwargs):
        super(ResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained=False)
        # Stored as a set so forward() can compare it against the supported
        # combinations regardless of how the caller passed it.
        self.loss = self._normalize_loss(loss)
        # Keep every child up to (but excluding) the final avgpool and fc.
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        if self.loss != {'metric'}:
            # Input width matches the removed fc layer's input features.
            self.classifier = nn.Linear(resnet50.fc.in_features, num_class)

    @staticmethod
    def _normalize_loss(loss):
        """Return *loss* as a set of loss names (``None`` -> both losses)."""
        if loss is None:
            return {'softmax', 'metric'}
        if isinstance(loss, str):
            # set('softmax') would split the name into characters.
            return {loss}
        return set(loss)

    def forward(self, x):
        """Run the backbone and return features and/or class scores.

        Args:
            x: image batch, presumably ``(batch, 3, H, W)`` -- TODO confirm.

        Returns:
            In eval mode, the pooled feature tensor ``f`` of shape
            ``(batch, C)``, where ``C`` is the backbone's output channel count.
            In training mode, the result depends on ``self.loss``:
            ``{'softmax'}`` -> classifier logits ``y``;
            ``{'metric'}`` -> ``f``;
            ``{'softmax', 'metric'}`` -> the tuple ``(y, f)``.

        Raises:
            ValueError: in training mode, if ``self.loss`` is not one of the
                combinations above.
        """
        x = self.base(x)
        # Global average pooling: the kernel spans the whole spatial extent,
        # leaving a (batch, C, 1, 1) tensor.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)  # feature vector, (batch, C)
        # Optional L2 normalisation of the features (currently disabled):
        # f = 1. * f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12)
        if not self.training:
            return f
        # Check 'metric' before touching self.classifier: that attribute is
        # not created when the loss is exactly {'metric'}.
        if self.loss == {'metric'}:
            return f
        if self.loss == {'softmax'}:
            return self.classifier(f)
        if self.loss == {'softmax', 'metric'}:
            return self.classifier(f), f
        raise ValueError('loss setting error: {}'.format(self.loss))
if __name__ == '__main__':
    # Manual smoke test: build a 751-class model, push a random batch of
    # 32 three-channel 256x128 images through it, then open an IPython
    # shell so `model`, `imgs` and `f` can be inspected interactively.
    model = ResNet50(num_class=751)
    imgs = torch.rand((32, 3, 256, 128))
    f = model(imgs)
    embed()
# resnet升级 (ResNet upgrade)
# 最新推荐文章于 2024-04-01 14:48:38 发布 (latest recommended article published 2024-04-01 14:48:38)