torch.nn.Module学习笔记

目录

1.add_module(name: str, module: Optional[Module]) → None

2.apply(fn: Callable[Module, None]) → T

3.bfloat16() → T

4.buffers(recurse: bool = True) → Iterator[torch.Tensor]

5.children() → Iterator[torch.nn.modules.module.Module]

6.cpu() → T

7.cuda(device: Union[int, torch.device, None] = None) → T

8.double() → T

9.dump_patches: BOOL = FALSE

10.eval() → T

11.extra_repr() → str

12.float() → T

13.forward(*input: Any) → None

14.half() → T

15.load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)

16.modules() → Iterator[torch.nn.modules.module.Module]

17.named_buffers(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, torch.Tensor]]

18.named_children() → Iterator[Tuple[str, torch.nn.modules.module.Module]]

19.named_modules(memo: Optional[Set[Module]] = None, prefix: str = '')

20.named_parameters(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, torch.Tensor]]

21.parameters(recurse: bool = True) → Iterator[torch.nn.parameter.Parameter]

22.register_backward_hook(hook: Callable[[Module, Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, torch.Tensor]]) → torch.utils.hooks.RemovableHandle

23.register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None

24.register_forward_hook(hook: Callable[..., None]) → torch.utils.hooks.RemovableHandle

25.register_forward_pre_hook(hook: Callable[..., None]) → torch.utils.hooks.RemovableHandle

26.register_parameter(name: str, param: Optional[torch.nn.parameter.Parameter]) → None

27.requires_grad_(requires_grad: bool = True) → T

28.state_dict(destination=None, prefix='', keep_vars=False) → dict

29.to(*args, **kwargs)

30.train(mode: bool = True) → T

31.type(dst_type: Union[torch.dtype, str]) → T

32.zero_grad(set_to_none: bool = False) → None


 

CLASS torch.nn.Module

 

它是所有神经网络模块的基类。

你的模型也应该是这个类的子类。

模块也可以包含其它的模块,以树形结构将它们嵌套到一起。你可以将子模块赋值给普通的属性

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

以这种方式赋值的子模块将被注册,当你调用.to()函数时,它们的参数将被转换到相应的设备上。

1.add_module(name: str, module: Optional[Module]) → None

功能:在当前模块中添加一个子模块。这个子模块作为类的属性,可以使用相应的变量名访问。

参考链接:

2.apply(fn: Callable[Module, None]) → T

3.bfloat16() → T

4.buffers(recurse: bool = True) → Iterator[torch.Tensor]

官网解释

exampel:

>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)

5.children() → Iterator[torch.nn.modules.module.Module]

6.cpu() → T

7.cuda(device: Union[int, torch.device, None] = None) → T

8.double() → T

9.dump_patches: BOOL = FALSE

10.eval() → T

11.extra_repr() → str

12.float() → T

13.forward(*input: Any) → None

14.half() → T

15.load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)

功能描述:复制来自state_dict()的参数和缓冲到这个模块和它的子节点。如果strict=True,state_dict的 keys 必须与模块的 state_dict() 函数返回的 keys 精确匹配。

参考链接:https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=torch%20nn%20module%20load_state_dict#torch.nn.Module.load_state_dict 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值