Pytorch 中的 torch.nn.Identity( ) 的作用

1 torch.nn.Identity( ) 的作用

torch.nn.Identity( ) 相当于一个恒等函数

f(x) = x

  • 这个函数相当于输入什么就输出什么, 可以用在对已经设计好模型结构的修改, 比如模型的最后一层是 1000 分类, 我们可以将最后一层用 nn.Identity( ) 替换掉, 得到它之前学习的特征, 然后再自己设计最后一层的结构

  • 在迁移学习中经常使用

2 示例

import torch
from torch import nn
from torch.nn import NLLLoss
import timm


class MiniModel(nn.Module):
    
    def __init__(self, backbone, num_class, pretrained=False, backbone_ckpt=None):
        
        super().__init__()
        self.backbone = timm.crear_model(backbone, pretrained=pretrained, checkpoint_path=backbone_ckpt)
        self.head = nn.Linear(self.backbone.get_classifier().in_features, num_class)
        # 替代最后一层的全连接网络
        self.backbone.head.fc = nn.Identity()
        self.loss_fn = NLLLoss()
        
    def forward(self, image, label):
        embed = self.backbone(image)
        logit = self.head(embed)
        
        if label is not None:
            logit_logsoftmax = torch.log_softmax(logit, 1)
            loss = self.loss_fn(logit_logsoftmax, label)
            return {"prediction": logit, "loss": loss}
        return {"prediction": logit}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值