目录
1.add_module(name: str, module: Optional[Module]) → None
2.apply(fn: Callable[Module, None]) → T
4.buffers(recurse: bool = True) → Iterator[torch.Tensor]
5.children() → Iterator[torch.nn.modules.module.Module]
7.cuda(device: Union[int, torch.device, None] = None) → T
13.forward(*input: Any) → None
15.load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)
16.modules() → Iterator[torch.nn.modules.module.Module]
17.named_buffers(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, torch.Tensor]]
18.named_children() → Iterator[Tuple[str, torch.nn.modules.module.Module]]
19.named_modules(memo: Optional[Set[Module]] = None, prefix: str = '')
20.named_parameters(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, torch.Tensor]]
21.parameters(recurse: bool = True) → Iterator[torch.nn.parameter.Parameter]
23.register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None
24.register_forward_hook(hook: Callable[..., None]) → torch.utils.hooks.RemovableHandle
25.register_forward_pre_hook(hook: Callable[..., None]) → torch.utils.hooks.RemovableHandle
26.register_parameter(name: str, param: Optional[torch.nn.parameter.Parameter]) → None
27.requires_grad_(requires_grad: bool = True) → T
28.state_dict(destination=None, prefix='', keep_vars=False) → dict
30.train(mode: bool = True) → T
31.type(dst_type: Union[torch.dtype, str]) → T
32.zero_grad(set_to_none: bool = False) → None
它是所有神经网络模块的基类。
你的模型也应该是这个类的子类。
模块也可以包含其它的模块,以树形结构将它们嵌套到一起。你可以将子模块赋值给普通的属性。
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
以这种方式赋值的子模块将被注册,当你调用.to()函数时,它们的参数将被转换到相应的设备上。
1.add_module(name: str, module: Optional[Module]) → None
功能:在当前模块中添加一个子模块。这个子模块作为类的属性,可以使用相应的变量名访问。
参考链接:
2.apply(fn: Callable[Module, None]) → T
3.bfloat16() → T
4.buffers(recurse: bool = True) → Iterator[torch.Tensor]
exampel:
>>> for buf in model.buffers():
>>> print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
5.children() → Iterator[torch.nn.modules.module.Module]
6.cpu() → T
7.cuda(device: Union[int, torch.device, None] = None) → T
8.double() → T
9.dump_patches: BOOL = FALSE
10.eval() → T
11.extra_repr() → str
12.float() → T
13.forward(*input: Any) → None
14.half() → T
15.load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)
功能描述:复制来自state_dict()的参数和缓冲到这个模块和它的子节点。如果strict=True,state_dict的 keys 必须与模块的 state_dict() 函数返回的 keys 精确匹配。