【Pytorch学习】--必要知识系列

Brief

本博客直接是对pytorch的中文文档的学习和理解。
这里是第一部分的内容。主要包括以下的5个方面

  • 自动求导机制
  • CUDA语义
  • 扩展pytorch
  • 多进程最佳实践
  • 序列化语义

1 自动求导机制

1.1从后向中排除子图

  1. 什么是子图,为什么要排除子图

我的理解是,在一个神经网路结构中,一个计算图也就是一个前向计算的过程,我们在BP的过程中可能需要对某些子图的权重不希望它更新,包括有以下的情形:

(1)dropout掉的神经元
(2)固定前面部分权重,只是对后面的网络结构更新对应的权重
(3)在联合任务中,希望Loss只对其中一部分权重更新有效。
在这里插入图片描述

  1. 两个重要的排除子图的标志
    (1)每个变量都有两个标志:requires_gradvolatile,只有把这两个标志其中一个设置为true的时候是可以执行梯度更新的。(对于冻结某一部分权重很有用)
    (2)volatilerequires_grad更建议在inference模式下使用
  2. 自动求导如何编码历史信息

整个图在每次迭代时都是从头开始重新创建的,这就允许使用任意的Python控制流语句,这样可以在每次迭代时改变图的整体形状和大小。
4. Variable上的In-place操作

In-place操作指的是原位操作,不增加新的内存。

文档中表示不建议使用。过~

2 CUDA语义

2. 1CUDA语义

torch.cuda属性显示的是当前tensor所使用的GPU号。可以使用torch.cuda.device进行修改。
在一般的情形下,一旦设定的GPU,就需要都在该GOU上进行数据操作,但是可以使用_copy()函数操作。

3 扩展

3.1 扩展torch.autograd

需要实现以下的三个函数:

  • init (optional)
  • forward()
  • backward()
    这也就是重新定义一个计算梯度的方法,后续如果用到再细看吧,这里暂且不学习。
    forward()的参数只能是Variable。函数的返回值既可以是 Variable也可以是Variables的tuple。
    下面试一个Liner的重新定义:(注意需要继承的是Function函数)
# Inherit from Function
class Linear(Function):

    # bias is an optional argument
    def forward(self, input, weight, bias=None):
        self.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        input, weight, bias = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if self.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if self.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and self.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

3.2 扩展 torch.nn

也就是我们最常用的模块继承,不细看了

5 序列化语义

5.1 保存模型的方法

  • 只保存和加载模型参数
torch.save(the_model.state_dict(), PATH)

加载:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
  • 保存和加载整个模型:
torch.save(the_model, PATH)
加载:
the_model = torch.load(PATH)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值