一、源码
阅读源码可知,identity模块不改变输入,直接返回输入。
class Identity(Module):
r"""A placeholder identity operator that is argument-insensitive.
Args:
args: any argument (unused)
kwargs: any keyword argument (unused)
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
Examples::
>>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 20])
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(Identity, self).__init__()
def forward(self, input: Tensor) -> Tensor:
return input
二、使用场景
举个例子,想用某个backbone时,最后一层本来是用作 分类的,用 softmax函数或者 fully connected 函数,但是用 nn.identtiy() 函数把最后一层替换掉,相当于得到分类之前的特征。
比如
backbone.fc, backbone.head = nn.Identity(), nn.Identity()