【pytorch API笔记5】nn.module类常见成员函数


参考
【PyTorch】torch.nn.Module 源码分析

多达48个函数,这里简单记录一下常见函数的作用
先创建一个Module,以这个为例

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

net = Model()

一、cpu(self)

将所有的参数和内存放在cpu上

net.cpu()  # 将所有的参数和内存放在cpu上

二、cuda(self, device=None)

将所有的参数和内存放在gpu上

net.gpu("cuda:0")  # 将所有的参数和内存放在gpu上

三、apply(self, fn)

将Module及其所有的SubModule传进给定的fn函数操作一遍。举个例子,我们可以用这个函数来对Module的网络模型参数用指定的方法初始化。

def init_weights(m): # 将所有子模型的linear参数赋值为1
     print(m)
     if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

四、type(self, dst_type)

type函数是将所有parameters和buffers都转成指定的目标类型dst_type

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.type(dst_type=torch.float16)
for model in net.children():
    print(model.weight.data)

五、float(self)、double(self)、half(self)、bfloat16(self)

float、double和half这三个函数是将所有floating point parameters分别转成float datatype、double datatype和half datatype。torch.Tensor.float即torch.float32;torch.Tensor.double即torch.float64;torch.Tensor.half即torch.float16。

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.half()
for model in net.children():
     print(model.weight.data)

六、to(self, *args, **kwargs)

函数to的作用是原地 ( in-place ) 修改Module,它可以当成三种函数来使用:function:: to(device=None, dtype=None, non_blocking=False); function:: to(dtype, non_blocking=False); function:: to(tensor, non_blocking=False)。下边展示的是使用方法。
这里直接拷贝官方例子

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

七、state_dict(self, destination=None, prefix=‘’, keep_vars=False)

函数state_dict的作用是返回一个包含module的所有state的dictionary,而这个字典的Keys对应的就是parameter和buffer的名字names。该函数的源码部分有一个循环可以递归遍历Module中所有的SubModule。

>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792,  0.5772], [ 0.1039, -0.0552]])), 
        ('0.bias', tensor([-0.5175, -0.6469])), 
        ('1.weight', tensor([[-0.5346, -0.0173], [-0.2092,  0.0794]])), 
        ('1.bias', tensor([-0.2150,  0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])

八、def train(self, mode=True)和 eval(self)

函数train和函数eval的作用是将Module及其SubModule分别设置为training mode和evaluation mode。这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm。

其他更多见
https://zhuanlan.zhihu.com/p/88712978

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值