Pytorch什么时候开始调用forward

import torch
from torch import nn

class MLP(nn.Module):
    # 声明带有模型参数的层,这里声明了两个全连接层
    def __init__(self, **kwargs):
        # 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
        # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256) # 隐藏层
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)  # 输出层


    # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
    def forward(self, x):
        print('调用forward函数')
        a = self.act(self.hidden(x))
        return self.output(a)




X = torch.rand(2, 784)
net = MLP() # 多层感知机,网络结构的初始化。

print(net)
print('*'*50)
net(X)

MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)
**************************************************
调用forward函数

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
是的,Pytorch中可以直接调用amsoftmax。 在PyTorch中,可以使用nn.CrossEntropyLoss()函数来计算softmax输出与实际标签之间的交叉熵损失。但是,如果你想使用amsoftmax,需要自定义损失函数。 以下是一个简单的amsoftmax实现示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class AMSoftmax(nn.Module): def __init__(self, in_feats, n_classes, m=0.35, s=30.0): super(AMSoftmax, self).__init__() self.m = m self.s = s self.in_feats = in_feats self.n_classes = n_classes self.weight = nn.Parameter(torch.Tensor(in_feats, n_classes)) nn.init.xavier_uniform_(self.weight) def forward(self, x, label): x_norm = F.normalize(x, p=2, dim=1) w_norm = F.normalize(self.weight, p=2, dim=0) logits = x_norm.mm(w_norm) target_logits = logits[torch.arange(0, x.size(0)), label].view(-1, 1) m_hot = torch.zeros_like(logits).scatter_(1, label.view(-1, 1), self.m) logits_m = logits - m_hot logits_scaled = logits_m * self.s loss = nn.CrossEntropyLoss()(logits_scaled, label) return loss ``` 在此实现中,我们通过继承nn.Module来创建一个自定义的AMSoftmax层。在前向传播中,我们首先使用F.normalize()函数对输入特征x和权重矩阵w进行L2归一化。然后,我们将二者相乘,得到logits。接着,我们从logits中提取出与目标标签对应的logit,并将其视为target_logits。接下来,我们创建一个大小与logits相同的张量m_hot,其中每个样本的目标类别位置用值为m的标量替换。最后,我们从logits_m中减去m_hot,然后将差乘以s,以得到缩放后的logits。最后,我们使用自定义的交叉熵损失函数计算损失并返回。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值