import torch
from torch import nn # PyTorch 的神经网络模块,提供了各种构建神经网络的组件和工具。
class Z(nn.Module): # 定义一个简单的神经网络模型 Z,它继承自 nn.Module:
def __init__(self): # Z 类的构造函数。
super(Z, self).__init__() # 调用父类 nn.Module 的构造函数,初始化父类的一些必要属性。
def forward(self, input): # 定义前向传播方法 forward,每当你将数据传入模型时,PyTorch 会自动调用这个方法。
input += 8
return input
# 创建了 Z 类的一个实例,命名为 Zilliax。
Zilliax = Z()
# 创建一个值为 1 的张量
x = torch.tensor(1)
# 由于 PyTorch 张量会以 tensor(...) 格式输出,所以打印的结果是tensor(9)
output = Zilliax(x)
print(output)
Pytorch最简单模型构建
于 2024-08-07 10:16:38 首次发布