1 torch.nn.Identity( ) 的作用
torch.nn.Identity( ) 相当于一个恒等函数
f(x) = x
这个函数相当于输入什么就输出什么, 可以用在对已经设计好模型结构的修改, 比如模型的最后一层是 1000 分类, 我们可以将最后一层用 nn.Identity( ) 替换掉, 得到它之前学习的特征, 然后再自己设计最后一层的结构
在迁移学习中经常使用
2 示例
import torch
from torch import nn
from torch.nn import NLLLoss
import timm
class MiniModel(nn.Module):
def __init__(self, backbone, num_class, pretrained=False, backbone_ckpt=None):
super().__init__()
self.backbone = timm.crear_model(backbone, pretrained=pretrained, checkpoint_path=backbone_ckpt)
self.head = nn.Linear(self.backbone.get_classifier().in_features, num_class)
# 替代最后一层的全连接网络
self.backbone.head.fc = nn.Identity()
self.loss_fn = NLLLoss()
def forward(self, image, label):
embed = self.backbone(image)
logit = self.head(embed)
if label is not None:
logit_logsoftmax = torch.log_softmax(logit, 1)
loss = self.loss_fn(logit_logsoftmax, label)
return {"prediction": logit, "loss": loss}
return {"prediction": logit}