Pytorch 中的 forward理解

前言

我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        # ......

    def forward(self, x):
        # ......
        return x


data = ......  # 输入数据

# 实例化一个对象
model = Module()

# 前向传播
model(data)

# 而不是使用下面的
# model.forward(data)  

但是实际上model(data)是等价于model.forward(data),这是为什么呢???下面我们来分析一下原因。

forward函数

model(data)之所以等价于model.forward(data),就是因为在类(class)中使用了__call__函数,对__call__函数不懂得可以点击链接:可调用:__call__函数

class Student:
    def __call__(self):
        print('I can be called like a function')


a = Student()
a()

输出结果:

I can be called like a function

由上面的__call__函数可知,我们可以将forward函数放到__call__函数中进行调用:

class Student:
    def __call__(self, param):
        print('I can called like a function')
        print('传入参数的类型是:{}   值为: {}'.format(type(param), param))

        res = self.forward(param)
        return res

    def forward(self, input_):
        print('forward 函数被调用了')

        print('in  forward, 传入参数类型是:{}  值为: {}'.format(type(input_), input_))
        return input_


a = Student()

input_param = a('data')
print("对象a传入的参数是:", input_param)

输出结果:

I can called like a function
传入参数的类型是:<class 'str'>   值为: data
forward 函数被调用了
in  forward, 传入参数类型是:<class 'str'>  值为: data
对象a传入的参数是: data

到这里我们就可以明白了为什么model(data)等价于model.forward(data),是因为__call__函数中调用了forward函数。

PyTorch,labels主要用于表示输入数据的目标输出值,特别是在训练神经网络模型时非常关键。它们用于指导模型学习如何将输入映射到正确的输出。 ### PyTorchlabels的作用 1. **监督学习**: 在监督学习任务,模型通过比较其预测结果与实际标签(labels),来进行损失函数计算,并据此更新权重,以最小化预测误差。 2. **分任务**: 对于图像分、文本分等任务,每个样本通常都会有一个对应的标签,这代表了样本属于特定别。例如,在手写数字识别任务,如果我们要识别的是0-9之间的数字,那么对于每一个输入图像,我们都有一个期望的标签作为目标。 3. **回归任务**: 在回归任务,labels可以理解为连续的数值,比如房价预测、温度预测等场景。 4. **损失计算**: 模型训练过程需要计算预测值与真实标签之间的差异,这个差异量化的结果就是损失值,损失函数的选择直接影响到模型优化的方向。 5. **评估性能**: 使用验证集或测试集时,通过比较模型的预测值与真实的标签,可以评估模型的泛化能力和预测准确率。 ### 使用示例 假设我们在使用PyTorch构建一个简单的二元分模型: ```python import torch from torch import nn, optim import torchvision.transforms as transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader # 定义模型 class SimpleClassifier(nn.Module): def __init__(self): super(SimpleClassifier, self).__init__() self.fc = nn.Linear(784, 1) # 输入28x28的图像展平后的维度,输出1维表示概率分布 def forward(self, x): return torch.sigmoid(self.fc(x)) # 加载数据集并预处理 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) train_dataset = MNIST(root='./data', train=True, transform=transform, download=True) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 创建模型实例 model = SimpleClassifier() # 初始化损失函数和优化器 criterion = nn.BCELoss() # 适用于二元分的任务 optimizer = optim.SGD(model.parameters(), lr=0.01) # 训练循环 for epoch in range(5): # 运行5个周期 for inputs, labels in train_loader: optimizer.zero_grad() # 前向传播 outputs = model(inputs.view(-1, 784)) # 计算损失 loss = criterion(outputs, labels.float()) # 反向传播和优化 loss.backward() optimizer.step() ``` 在这个例子,`labels`是从MNIST数据集获取的真实标签,通常是一个张量数组,形状与输入图片的批次大小一致。每个元素代表该批次对应样本的分标签。 ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值