nn.Module基本框架
创建神经网络有一个统一的模板,如下所示
第一步
导入需要的包
import torch import torch.nn as nn
第二步
创建神经网络
class GSW(nn.Module): def __init__(self): super(GSW, self).__init__() def forward(self, input): output = input + 1 return output
第三步
训练网络 此处简单举例,并未使用数据集
gsw = GSW() x = torch.tensor(1.0) output = gsw(x) print(output)
完整代码
# 开发时间: 2021/11/21 17:30 import torch import torch.nn as nn class GSW(nn.Module): def __init__(self): super(GSW, self).__init__() def forward(self, input): output = input + 1 return output gsw = GSW() x = torch.tensor(1.0) output = gsw(x) print(output)