resnet升级

# -*- encoding: utf-8 -*-
"""
@File    : ResNet.py
@Time    : 2021-05-08 14:50
@Author  : XD
@Email   : gudianpai@qq.com
@Software: PyCharm
"""
from __future__ import absolute_import

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed

class ResNet50(nn.Module):
    def __init__(self, num_class, loss = {'softmax, metric'},**kwargs):
        super(ResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained = False)
        self.loss = loss
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        if not self.loss == {'metric'}:
            self.classifier = nn.Linear(2048, num_class)

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x,x.size()[2:])
        f = x.view(x.size(0), -1) #future
        #归一化特征
        #f = 1. * f / (torch.norm(f, 2, dim = -1, keepdim = True).expand_as(f) + 1e-12)

        if not self.training:
            return f
        y = self.classifier(f)
        if self.loss == {'softmax'}:
            return y
        elif self.loss == {'metric'}:
            return f
        elif self.loss == {'softmax','metric'}:
            return y, f
        else:
            print('loss setting error')






if __name__ == '__main__':
    model = ResNet50(num_class = 751)
    imgs = torch.rand(32, 3 , 256, 128)
    f = model(imgs)
    embed()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值