今天在看pytorch的代码时,看到了torch.nn 和 torch.nn.functional,然后查了两个模块的官方doc,也没有看明白有啥区别,然后就查了一下资料,这里记录一下,方便以后查阅。
【pytorch】torch.nn 与 torch.nn.functional 的区别
torch.nn 与 torch.nn.functional 的区别
torch.nn.X | torch.nn.functional.X |
---|---|
是 类 | 是函数 |
结构中包含所需要初始化的参数 | 需要在函数外定义并初始化相应参数,并作为参数传入 |
一般情况下放在_init_ 中实例化,并在forward中完成操作 | 一般在_init_ 中初始化相应参数,在forward中传入 |
代码解释
torch.nn
torch.nn 这个模块下面存的主要是 Module类。
以torch.nn.Conv2d为例, 也就是说 torch.nn.Conv2d这种"函数"其实是个 Module类。
在实例化类后会初始化2d卷积所需要的参数. 这些参数会在你做forward和 backward之后根据loss进行更新,所以通常存放在定义模型的 init() 中.如:
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
#其实这里就是类的实例化,需要定义初始参数
self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
self.act = nn.ReLU()
def forward(self, x):
x = self.act(self.conv1(x))
return x
那在定义模型时,可不可以把nn.Conv2d写在forward处?
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.act = nn.ReLU()
def forward(self, x):
# 把卷积函数写在forward中
x= nn.Conv2d(3, 6, 3, 1, 1)(x)
x = self.act(x)
return x
把nn.Conv2d写在forward中就相当于模型每次跑forward的时候,都重新实例化了nn.Conv2d和nn.Conv2d的参数,导致模型学不到参数.
torch.nn.functional
torch.nn.functional.x 为函数。
与torch.nn不同, torch.nn.x中包含了初始化需要的参数等 attributes 而torch.nn.functional.x则需要把相应的weights 作为输入参数传递,才能完成运算, 所以用torch.nn.functional创建模型时需要创建并初始化相应参数.
例如:
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.act = nn.ReLU()
self.weighs = nn.Parameter(torch.rand(x,x,x,x)) # 初始化参数
self.bias = nn.Parameter(torch.rand(x)) # 初始化参数
def forward(self, x):
# 把卷积函数写在forward中,把w和b传入函数
x= F.conv2d(x,self.weighs,self.bias)
x = self.act(x)
return x