1.Module类介绍
2.forward()方法
子类继承nn.Module需要重写forward()方法
3.nn.Module的代码实战
import torch
from torch import nn
class Mymodule(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def forward(self, input): # 注意是forward()不是__forward__()
output = input + 1
return output
tensor = torch.tensor(1.0)
my_module = Mymodule() # 声明Mymodule类型对象
print(my_module(tensor)) # my_module(tensor)调用forward()方法