- nn.Module是Pytorch封装的一个类,是搭建神经网络时需要继承的父类:
import torch
import torch.nn as nn
# 括号中加入nn.Module(父类)。Test2变成子类,继承父类(nn.Module)的所有特性。
class Test2(nn.Module):
def __init__(self): # Test2类定义初始化方法
super(Test2, self).__init__() # 父类初始化
self.M = nn.Parameter(torch.ones(10))
def weightInit(self):
print('Testing')
def forward(self, n):
# print(2 * n)
print(self.M * n)
self.weightInit()
# 调用方法
network = Test2()
network(2) # 2赋值给forward(self, n)中的n。
……省略一部分代码……
# 因为Test2是nn.Module的子类,所以也可以执行父类中的方法。如:
model_dict = network.state_dict() # 调用父类中的方法state_dict(),将Test2中训练参数赋值model_dict。
for k, v in model_dict.items(): # 查看自己网络参数各层名称、数值
print(k) # 输出网络参数名字
# print(v) # 输出网络参数数值
继承nn.Module的子类程序是从forward()方法开始执行的,如果要想执行其他方法,必须把它放在forward()方法中。这一点与python中继承有稍许的不同。