1、概念
nn.Module
是 PyTorch 中的一个基类,用于构建神经网络模型。所有的神经网络模型都应该继承自 nn.Module
类,这样可以利用 PyTorch 提供的模块化和自动求导等功能。
nn.Module
提供了一些有用的方法和属性,包括:
-
参数管理:
nn.Module
能够跟踪并管理所有注册为模型参数的张量,通过parameters()
方法可以访问模型的所有参数。 -
子模块管理: 你可以在模型中包含其他模型作为子模块,通过
children()
和modules()
方法访问模型的所有子模块。 -
前向传播方法(
forward
): 所有的模型都需要实现forward
方法,该方法定义了输入数据的前向传播过程。 -
to 方法: 用于将模型移动到不同的设备(例如,从 CPU 到 GPU)。
使用 nn.Module
可以更轻松地构建、管理和组织复杂的神经网络结构。通过继承和使用 nn.Module
,你可以定义自己的神经网络层和模型,并且 PyTorch 将负责自动求导、参数优化等底层细节。
2、代码
import torch from torch import nn class Xuex(nn.Module): def __init__(self) -> None: super().__init__() def forward(self,input): output=input+1 return output xuex=Xuex() x=torch.tensor(9) out=xuex(x) print(f"out:{out}")
鼠标左键点击行数旁边的空白区域添加断点,再点击这个昆虫图标
再点击“单步执行代码”
可以看到目前程序执行的位置
再次点击“单步执行”,已经创建名为“xuex”的网络框架了
“单步执行”,传入参数