pytorch中的forward是如何调用的?

0 背景

给出一段代码,可以看到CNN是一个类,model = CNN()为类的实例化。
问题:outputs = model(imgs)一个类的实例,为什么能作为一个函数直接调用?并且调用的是forward方法?

...
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            # 卷积层
            ...
        )
    def forward(self, x):
        return self.net(x)

...

model = CNN()  # 模型实例化

...

for epoch in range(epochs):
    train_loss = 0
    for imgs, labels in train_dataloader:
        imgs = imgs
        labels = labels
        outputs = model(imgs)    # 前向计算
        loss = loss_func(outputs, labels)    # 损失函数计算
        train_loss += loss.item() * imgs.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    # 更新优化器

1 原码解析

版本信息:

python 3.10.0
pytorch 2.1.0

  • torch.nn.Module的代码,在torch\nn\modules\module.py,下面这段原码是精简版的。
    调用原理: python类的__call__方法支持实例可以作为函数来调用,即调用实例就是调用__call__
    核心逻辑:nn.Module__call__被定义为__wrapped_call_impl,进而会调用到self.forward
...
class Module:
...
    def __init__(self, *args, **kwargs) -> None:
    ...
    __call__ : Callable[..., Any] = _wrapped_call_impl
    ...
    def _wrapped_call_impl(self, *args, **kwargs):
        if self._compiled_call_impl is not None:
            return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
        else:
            return self._call_impl(*args, **kwargs)

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*args, **kwargs)
...

2 引申

有了这个核心逻辑,我们也可以实现一个简单的调用。
这里,test方法可以类比于nn.Moduleforward方法。

另外,在Python中,Callable[..., Any]是一个类型注解,用来描述一个可以被调用的对象。
这里的语法解释如下:
Callable是一个类型,表示一个可调用的类型。
[...]表示参数类型是可变的。这意味着可以传递任意数量的参数给这个可调用的对象。
Any表示返回类型是任意的,即这个可调用的对象可以返回任何类型的结果。

from typing import Any, Callable

class Module:
    def __init__(self):
        return

    def test(self, *args):
        print("success:", sum(args))
        return

    __call__ : Callable[..., Any] = test

A = Module()
output = A(5,6,7)

输出

success: 18

  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
是的,Pytorch可以直接调用amsoftmax。 在PyTorch,可以使用nn.CrossEntropyLoss()函数来计算softmax输出与实际标签之间的交叉熵损失。但是,如果你想使用amsoftmax,需要自定义损失函数。 以下是一个简单的amsoftmax实现示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class AMSoftmax(nn.Module): def __init__(self, in_feats, n_classes, m=0.35, s=30.0): super(AMSoftmax, self).__init__() self.m = m self.s = s self.in_feats = in_feats self.n_classes = n_classes self.weight = nn.Parameter(torch.Tensor(in_feats, n_classes)) nn.init.xavier_uniform_(self.weight) def forward(self, x, label): x_norm = F.normalize(x, p=2, dim=1) w_norm = F.normalize(self.weight, p=2, dim=0) logits = x_norm.mm(w_norm) target_logits = logits[torch.arange(0, x.size(0)), label].view(-1, 1) m_hot = torch.zeros_like(logits).scatter_(1, label.view(-1, 1), self.m) logits_m = logits - m_hot logits_scaled = logits_m * self.s loss = nn.CrossEntropyLoss()(logits_scaled, label) return loss ``` 在此实现,我们通过继承nn.Module来创建一个自定义的AMSoftmax层。在前向传播,我们首先使用F.normalize()函数对输入特征x和权重矩阵w进行L2归一化。然后,我们将二者相乘,得到logits。接着,我们从logits提取出与目标标签对应的logit,并将其视为target_logits。接下来,我们创建一个大小与logits相同的张量m_hot,其每个样本的目标类别位置用值为m的标量替换。最后,我们从logits_m减去m_hot,然后将差乘以s,以得到缩放后的logits。最后,我们使用自定义的交叉熵损失函数计算损失并返回。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值