torch.nn 与 torch.nn.functional的区别?如何选择?

前言:

在ptorch中,torch.nntorch.nn.functional模块下包含了许多功能相同、用法相似的方法,比如:

1. torch.nn.Softmax(dim=1)(x) 和 torch.nn.functional.softmax(x, dim=1)实现了对x在行维度上进行softmax,

2. torch.nn.Conv2d(3, 6, 5)(x) 和 torch.nn.functional.conv2d(x, weight=torch.rand(6, 3, 5, 5))实现了对x进行同尺寸的卷积,

3. torch.nn.ReLU()(x) 和 torch.nn.functional.relu(x)实现了对x进行relu激活。

接下来我将总结这两个模块下方法的区别,联系及如何选择。

 

一、区别:

1. 类型:

(1) torch.nn中包含的是封装好的类,继承自nn.Module,调用时先实例化,然后在__init__()中初始化,再在forward()中进行操作。

(2) torch.nn.functional中包含的是实现好的函数,直接通过接口调用。

2. 调用方法:

以二维卷积操作为例:

(1) 调用torch.nn.Conv2d()类时,会先进行实例化,在构造函数中帮你定义好weight, bias变量,你只需传入相关参数,而不需要手动维护这些变量。然后会把定义好的变量作为参数传入nn.functional.conv2d()函数中,以函数调用的方法实现卷积。调用方式为torch.nn.Conv2d()(x)。

(2) 调用torch.nn.functional.conv2d()函数时,需要手动地创建好weight, bias变量,然后作为参数传入函数中调用。调用方式为torch.nn.functional.conv2d(x)

3. 继承自nn.Module的方法

(1) 前面提到,torch.nn下的类继承自nn.Module,所以nn模块下的类除了具有nn.functional函数的功能外,还继承了nn.Module的属性和方法,包括我们常用的train(), eval(), cuda()等等。

(2) 此外由于继承自nn.Module,torch.nn模块中的类能与nn.Sequential结合使用,而nn.functional中的函数无法与nn.Sequential结合。

 

二、联系:

torch.nn的类会在forward()方法中调用torch.nn.functional的函数,所以可以理解为nn模块中的方法是对nn.functional模块中方法的更高层的封装。

 

三、如何选择:

1. 何时选择torch.nn 

在定义深度神经网络的layer时推荐使用nn模块。一是因为当定义有变量参数的层时(比如conv2d, linear, batch_norm),nn模块会帮助我们初始化好变量,而我们只需要传入一些参数;二是因为model类本身是nn.Module,看起来会比较协调统一;三是因为可以结合nn.Sequential。

此外当使用dropout时推荐使用nn模块,因为可以在测试阶段通过eval()方法方便地关闭dropout。

2. 何时选择torch.nn.functional

nn.functional中的函数相比nn更偏底层,所以虽然封装性不高,但透明度很高,可以在其基础上定义出自己想要的功能。

  • 12
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值