The following code demonstrates this:
import torch
import torch.nn as nn


class A(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

    # Overridden only to show that a(x) goes through __call__.
    def __call__(self, x):
        print("Start running __call__ function")
        return self.forward(x)


a = A()
t1 = torch.rand((2, 5))
print(a(t1))
Program output:
Start running __call__ function
tensor([[0.9715, 0.2507, 0.9550, 0.9101, 0.9695],
        [0.0976, 0.8725, 0.7675, 0.9817, 0.7045]])
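The example above overrides __call__ purely for demonstration. In practice nn.Module already provides a __call__ that dispatches to forward and also runs any registered hooks, so overriding it is rarely needed. The sketch below is one way to observe that difference: the class B, the hook record_hook, and the list calls are hypothetical names introduced here for illustration, not part of the original example.

```python
import torch
import torch.nn as nn


class B(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * 2


calls = []  # records each time the forward hook fires


def record_hook(module, inputs, output):
    # Forward hooks are run by nn.Module.__call__, after forward returns.
    calls.append(type(module).__name__)


b = B()
handle = b.register_forward_hook(record_hook)

x = torch.ones(2, 5)
y_call = b(x)             # goes through __call__, so the hook fires
y_forward = b.forward(x)  # bypasses __call__, so the hook does not fire
handle.remove()

print(calls)                          # the hook fired once, for b(x) only
print(torch.equal(y_call, y_forward)) # both paths compute the same result
```

Both calls return the same tensor, but only b(x) triggers the hook. This is the usual reason for calling the module instance rather than calling forward directly.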