PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别

问题描述:

提示:刚入门深度学习,记录一些犯下的小错误:

由于本周开始试图复现华为的CTR库以增加记忆,熟悉代码细节,没想到第一天看基础模块的时候就遇到了麻烦,在torch.utils类中,有如下获取损失函数的代码块:

def get_loss_fn(loss):
    if isinstance(loss, str):
        if loss in ["bce", "binary_crossentropy", "binary_cross_entropy"]:
            loss = "binary_cross_entropy"
    try:
        loss_fn = getattr(torch.functional.F, loss)
    except:
        try:
            from . import losses
            loss_fn = getattr(losses, loss)
        except:
            raise NotImplementedError("loss={} is not supported.".format(loss))
    return loss_fn

其中getattr()函数是用于返回一个对象属性值(Tip: class中的方法也是一种对象属性),因此可以看出第6行代码的作用就是返回torch.functional.F这个类中的loss函数,那么问题来了:上面代码片中的torch.functional.F是哪个类呢,或者说是哪个模块呢?之前在学习PyTorch的过程中只接触过其中的:

import torch.nn.functional as F

那么这个torch.functional.Ftorch.nn.functional有何区别?


解惑

因此抱着分辨清楚的目的,查看PyTorch官方文档,我发现只有torch.nn.functional才有一系列的loss函数的实现,而输入关键词torch.functional在搜索引擎上基本找不到相关的资料,返回的搜索结果都是与前者相关的文档。于是我决定去看源码弄清楚:

分别找到两个模块所在位置
可以看到这两个模块显然是不同的模块!!!而后我打开torch.functional.py文件,出现了我无语的一幕,原来在torch.functional.py的第一行就是这么写的:

无语了啊
问题解决了,torch.functional .F指向的就是torch.nn.functional,可能刚开始试图复现这个CTR库吧,实在搞不懂作者为什么不直接直接使用torch.nn.functional来指代?

END~


  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值