作为初学者,看代码一度很迷惑,module中的forward函数中for循环,输入的Tensor数据是在同一个网络循环,还是依次向前推进了多个不同的网络。于是,我经过了下面的测试。
import torch.nn as nn
class Model1(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.block = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
x = self.block(x) + x
x = self.block(x) + x
x = self.block(x) + x
return x
class Model2(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.block = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
for i in range(3):
x = self.block(x) + x
return x
model1 = Model1(hidden_dim=10)
model2=Model2(hidden_dim=10)
print(model1)
print(model2)
得到以下结果:
Model1(
(block): Linear(in_feature=10, out_feature=10, bias=True)
)
Model2(
(block): Linear(in_feature=10, out_feature=10, bias=True)
)
然后我就悟了!一个Module的结构到底是由什么构成的,是__init__()还是forward()?结论是__init__()决定了Module有哪些网络,forward()决定了Module的网络是如何连接的。在forward()中无论如何调用__init__()中定义的某个网络,始终都是同一个网络。
那么文章开头那个问题的答案就有了,答案是:for循环中,通过的是同一个网络。