用一段小程序来说明
先看torch.nn.ReLU()
import torch
input = torch.Tensor([[1.0, 2.0],
[-1.0, 3.0]])
class demo(torch.nn.Module):
def __init__(self):
super(demo, self).__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(x)
model = demo()
print(model(input))
>>
tensor([[1., 2.],
[0., 3.]])
torch.nn.ReLU必须在继承torch.nn.Module的情况下使用
再看torch.nn.functional.relu()
import torch
input = torch.Tensor([[1.0, 2.0],
[-1.0, 3.0]])
print(torch.nn.functional.relu(input))
>>
tensor([[1., 2.],
[0., 3.]])
torch.nn.functional.relu更像是独立的一个函数,可以自由使用