大部分nn中的层class都有nn.function对应,其区别是:
- nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Parameter
- nn.functional中的函数更像是纯函数,由def function(input)定义。
由于两者性能差异不大,所以具体使用取决于个人喜好。对于激活函数和池化层,由于没有可学习参数,一般使用nn.functional完成,其他的有学习参数的部分则使用类。但是Droupout由于在训练和测试时操作不同,所以建议使用nn.Module实现,它能够通过model.eval加以区分。
一、nn.functional函数基本使用
1 2 3 4 5 6 7 8 9 10 11 12 13 | import torch as t import torch.nn as nn from torch.autograd import Variable as V input_ = V(t.randn( 2 , 3 )) model = nn.Linear( 3 , 4 ) output1 = model(input_) output2 = nn.functional.linear(input_, model.weight, model.bias) print (output1 = = output2) b1 = nn.functional.relu(input_) b2 = nn.ReLU()(input_) print (b1 = = b2) |
二、搭配使用nn.Module和nn.functional
并不是什么难事,之前有接触过,nn.functional不需要放入__init__进行构造,所以不具有可学习参数的部分可以使用nn.functional进行代替。
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | # Author : Hellcat # Time : 2018/2/11 import torch as t import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__( self ): super (LeNet, self ).__init__() self .conv1 = nn.Conv2d( 3 , 6 , 5 ) self .conv2 = nn.Conv2d( 6 , 16 , 5 ) self .fc1 = nn.Linear( 16 * 5 * 5 , 120 ) self .fc2 = nn.Linear( 120 , 84 ) self .fc3 = nn.Linear( 84 , 10 ) def forward( self ,x): x = F.max_pool2d(F.relu( self .conv1(x)),( 2 , 2 )) x = F.max_pool2d(F.relu( self .conv2(x)), 2 ) x = x.view(x.size()[ 0 ], - 1 ) x = F.relu( self .fc1(x)) x = F.relu( self .fc2(x)) x = self .fc3(x) return x |
三、nn.functional函数构造nn.Module类
两者主要的区别就是对于可学习参数nn.Parameter的识别能力,所以构造时添加了识别能力即可。
『PyTorch』第七弹_nn.Module扩展层
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | class Linear(nn.Module): def __init__( self , in_features, out_features): # nn.Module.__init__(self) super (Linear, self ).__init__() self .w = nn.Parameter(t.randn(out_features, in_features)) # nn.Parameter是特殊Variable self .b = nn.Parameter(t.randn(out_features)) def forward( self , x): # wx+b return F.linear(x, self .w, self .b) layer = Linear( 4 , 3 ) input = V(t.randn( 2 , 4 )) output = layer( input ) print (output) |
Variable containing:
1.7498 -0.8839 0.5314
-2.4863 -0.6442 1.1036
[torch.FloatTensor of size 2x3]