import torch
from torch import nn
class MyModule(nn.Module):
    """Minimal ``nn.Module`` that adds 1 to its input.

    The module registers no parameters, buffers or submodules; it only
    demonstrates that calling a module instance dispatches to ``forward``.
    """

    def __init__(self) -> None:
        # nn.Module.__init__ sets up the module's internal registries; it must
        # run before anything could be registered, even though nothing is here.
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input + 1``.

        Args:
            input: Value to offset by one (the driver below passes a tensor).
                The name shadows the builtin ``input`` but is kept so that
                keyword callers such as ``module(input=x)`` keep working.

        Returns:
            The result of ``input + 1``. For a tensor this is a new tensor;
            the argument is not modified in place.
        """
        output = input + 1
        return output
# Instantiate the module. Calling the instance goes through
# nn.Module.__call__, which dispatches to MyModule.forward (running any
# registered hooks as well) -- never call forward() directly.
module = MyModule()
x = torch.tensor(1.0)  # 0-dim float tensor
output = module(x)
print(output)  # prints: tensor(2.)
# Reference (参考地址): https://www.bilibili.com/video/BV1hE411t7RN?p=16