pytorch_course3

import torch
torch.__version__
'1.12.1'

使用PyTorch计算梯度数值

PyTorch的Autograd模块实现了深度学习的算法中的向传播求导数,在张量(Tensor类)上的所有操作,Autograd都能为他们自动提供微分,简化了手动计算导数的复杂过程。
在0.4以前的版本中,pytorch使用Variable类来自动计算所有的梯度,Variable类主要包含三个属性

  • data: 保存Variable所包含的Tensor
  • grad: 保存data对应的梯度,grad也是个Variable,而不是Tensor,它和data的形状一样
  • grad_fn: 指向一个Function对象,这个Function用来反向传播计算输入的梯度

从0.4起, Variable 正式合并入Tensor类,通过Variable嵌套实现的自动微分功能已经整合进入了Tensor类中。虽然为了代码的兼容性还是可以使用Variable(tensor)这种方式进行嵌套,但是这个操作其实什么都没做。

所以,以后的代码建议直接使用Tensor类进行操作,因为官方文档中已经将Variable设置成过期模块。

要想通过Tensor类本身就支持了使用autograd功能,只需要设置.requries_grad=True

Variable类中的的grad和grad_fn属性已经整合进入了Tensor类中

Autograd

在张量创建时,通过设置 requires_grad 标识为Ture来告诉Pytorch需要对该张量进行自动求导,PyTorch会记录该张量的每一步操作历史并自动计算

# 首先要引入相关的包
import torch
import numpy as np
#打印一下版本
torch.__version__
'1.12.1'
x = torch.rand(5,5, requires_grad=True)
x
tensor([[0.4807, 0.2961, 0.3068, 0.3815, 0.5506],
        [0.8147, 0.4751, 0.0330, 0.5118, 0.2246],
        [0.6915, 0.0553, 0.6438, 0.8483, 0.2649],
        [0.3101, 0.7949, 0.9997, 0.9636, 0.4530],
        [0.3466, 0.8072, 0.3508, 0.5186, 0.1210]], requires_grad=True)
y = torch.rand(5,5, requires_grad=True)
y
tensor([[0.8328, 0.9649, 0.3337, 0.4130, 0.1008],
        [0.9288, 0.5346, 0.3576, 0.8023, 0.9534],
        [0.0559, 0.0078, 0.6855, 0.3787, 0.2493],
        [0.4254, 0.5334, 0.0279, 0.3771, 0.2242],
        [0.7845, 0.5027, 0.5006, 0.7460, 0.5400]], requires_grad=True)

PyTorch会自动追踪和记录对与张量的所有操作,当计算完成后调用.backward()方法自动计算梯度并且将计算结果保存到grad属性中。

z = torch.sum(x+y)
z
tensor(24.5049, grad_fn=<SumBackward0>)

在张量进行操作后,grad_fn已经被赋予了一个新的函数,这个函数引用了一个创建了这个Tensor类的Function对象。
Tensor和Function互相连接生成了一个非循环图,它记录并且编码了完整的计算历史。每个张量都有一个.grad_fn属性,如果这个张量是用户手动创建的那么这个张量的grad_fn是None。

下面我们来调用反向传播函数,计算其梯度

简单的自动求导

z.backward()
print(x.grad, y.grad)
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]) tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

如果Tensor类表示的是一个标量(即它包含一个元素的张量),则不需要为backward()指定任何参数,但是如果它有更多的元素,则需要指定一个gradient参数,它是形状匹配的张量。
以上的 z.backward()相当于是z.backward(torch.tensor(1.))的简写。
这种参数常出现在图像分类中的单标签分类,输出一个标量代表图像的标签。

复杂的自动求导

x = torch.rand(5, 5, requires_grad=True)
y = torch.rand(5, 5, requires_grad=True)
z= x**2+y**3
z
tensor([[0.1840, 0.5713, 0.8338, 0.8437, 0.1953],
        [0.5606, 0.6024, 0.0734, 1.3823, 0.2659],
        [0.5723, 0.2056, 0.2090, 0.0324, 0.2047],
        [0.0553, 0.7134, 0.4190, 0.8945, 0.9590],
        [0.2629, 0.3860, 1.4279, 0.2361, 0.5961]], grad_fn=<AddBackward0>)
#我们的返回值不是一个标量,所以需要输入一个大小相同的张量作为参数,这里我们用ones_like函数根据x生成一个张量
z.backward(torch.ones_like(x))
x.grad
tensor([[0.6485, 1.4734, 1.7984, 1.8328, 0.3542],
        [1.4908, 1.4917, 0.5197, 1.5503, 0.9997],
        [1.4083, 0.7759, 0.9028, 0.0373, 0.5013],
        [0.3622, 1.5190, 1.1054, 1.1834, 0.6520],
        [0.4989, 1.2394, 1.8260, 0.9714, 1.4498]])

我们可以使用with torch.no_grad()上下文管理器临时禁止对已设置requires_grad=True的张量进行自动求导。这个方法在测试集计算准确率的时候会经常用到,例如:

with torch.no_grad():
    print((x+y*2).requires_grad)
False

使用.no_grad()进行嵌套后,代码不会跟踪历史记录,也就是说保存的这部分记录会减少内存的使用量并且会加快少许的运算速度。

Autograd 过程解析

为了说明Pytorch的自动求导原理,我们来尝试分析一下PyTorch的源代码,虽然Pytorch的 Tensor和 TensorBase都是使用CPP来实现的,但是可以使用一些Python的一些方法查看这些对象在Python的属性和状态。

Python的 dir() 返回参数的属性、方法列表。z是一个Tensor变量,看看里面有哪些成员变量。

dir(z)
['H',
 'T',
 '__abs__',
 '__add__',
 '__and__',
 '__array__',
 '__array_priority__',
 '__array_wrap__',
 '__bool__',
 '__class__',
 '__complex__',
 '__contains__',
 '__deepcopy__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__div__',
 '__dlpack__',
 '__dlpack_device__',
 '__doc__',
 '__eq__',
 '__float__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__iadd__',
 '__iand__',
 '__idiv__',
 '__ifloordiv__',
 '__ilshift__',
 '__imod__',
 '__imul__',
 '__index__',
 '__init__',
 '__init_subclass__',
 '__int__',
 '__invert__',
 '__ior__',
 '__ipow__',
 '__irshift__',
 '__isub__',
 '__iter__',
 '__itruediv__',
 '__ixor__',
 '__le__',
 '__len__',
 '__long__',
 '__lshift__',
 '__lt__',
 '__matmul__',
 '__mod__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__nonzero__',
 '__or__',
 '__pos__',
 '__pow__',
 '__radd__',
 '__rand__',
 '__rdiv__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__rfloordiv__',
 '__rlshift__',
 '__rmatmul__',
 '__rmod__',
 '__rmul__',
 '__ror__',
 '__rpow__',
 '__rrshift__',
 '__rshift__',
 '__rsub__',
 '__rtruediv__',
 '__rxor__',
 '__setattr__',
 '__setitem__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__torch_dispatch__',
 '__torch_function__',
 '__truediv__',
 '__weakref__',
 '__xor__',
 '_addmm_activation',
 '_autocast_to_full_precision',
 '_autocast_to_reduced_precision',
 '_backward_hooks',
 '_base',
 '_cdata',
 '_coalesced_',
 '_conj',
 '_conj_physical',
 '_dimI',
 '_dimV',
 '_fix_weakref',
 '_grad',
 '_grad_fn',
 '_indices',
 '_is_view',
 '_is_zerotensor',
 '_make_subclass',
 '_make_wrapper_subclass',
 '_neg_view',
 '_nested_tensor_layer_norm',
 '_nnz',
 '_python_dispatch',
 '_reduce_ex_internal',
 '_storage',
 '_to_dense',
 '_update_names',
 '_values',
 '_version',
 'abs',
 'abs_',
 'absolute',
 'absolute_',
 'acos',
 'acos_',
 'acosh',
 'acosh_',
 'add',
 'add_',
 'addbmm',
 'addbmm_',
 'addcdiv',
 'addcdiv_',
 'addcmul',
 'addcmul_',
 'addmm',
 'addmm_',
 'addmv',
 'addmv_',
 'addr',
 'addr_',
 'adjoint',
 'align_as',
 'align_to',
 'all',
 'allclose',
 'amax',
 'amin',
 'aminmax',
 'angle',
 'any',
 'apply_',
 'arccos',
 'arccos_',
 'arccosh',
 'arccosh_',
 'arcsin',
 'arcsin_',
 'arcsinh',
 'arcsinh_',
 'arctan',
 'arctan2',
 'arctan2_',
 'arctan_',
 'arctanh',
 'arctanh_',
 'argmax',
 'argmin',
 'argsort',
 'argwhere',
 'as_strided',
 'as_strided_',
 'as_subclass',
 'asin',
 'asin_',
 'asinh',
 'asinh_',
 'atan',
 'atan2',
 'atan2_',
 'atan_',
 'atanh',
 'atanh_',
 'backward',
 'baddbmm',
 'baddbmm_',
 'bernoulli',
 'bernoulli_',
 'bfloat16',
 'bincount',
 'bitwise_and',
 'bitwise_and_',
 'bitwise_left_shift',
 'bitwise_left_shift_',
 'bitwise_not',
 'bitwise_not_',
 'bitwise_or',
 'bitwise_or_',
 'bitwise_right_shift',
 'bitwise_right_shift_',
 'bitwise_xor',
 'bitwise_xor_',
 'bmm',
 'bool',
 'broadcast_to',
 'byte',
 'cauchy_',
 'ccol_indices',
 'cdouble',
 'ceil',
 'ceil_',
 'cfloat',
 'chalf',
 'char',
 'cholesky',
 'cholesky_inverse',
 'cholesky_solve',
 'chunk',
 'clamp',
 'clamp_',
 'clamp_max',
 'clamp_max_',
 'clamp_min',
 'clamp_min_',
 'clip',
 'clip_',
 'clone',
 'coalesce',
 'col_indices',
 'conj',
 'conj_physical',
 'conj_physical_',
 'contiguous',
 'copy_',
 'copysign',
 'copysign_',
 'corrcoef',
 'cos',
 'cos_',
 'cosh',
 'cosh_',
 'count_nonzero',
 'cov',
 'cpu',
 'cross',
 'crow_indices',
 'cuda',
 'cummax',
 'cummin',
 'cumprod',
 'cumprod_',
 'cumsum',
 'cumsum_',
 'data',
 'data_ptr',
 'deg2rad',
 'deg2rad_',
 'dense_dim',
 'dequantize',
 'det',
 'detach',
 'detach_',
 'device',
 'diag',
 'diag_embed',
 'diagflat',
 'diagonal',
 'diagonal_scatter',
 'diff',
 'digamma',
 'digamma_',
 'dim',
 'dist',
 'div',
 'div_',
 'divide',
 'divide_',
 'dot',
 'double',
 'dsplit',
 'dtype',
 'eig',
 'element_size',
 'eq',
 'eq_',
 'equal',
 'erf',
 'erf_',
 'erfc',
 'erfc_',
 'erfinv',
 'erfinv_',
 'exp',
 'exp2',
 'exp2_',
 'exp_',
 'expand',
 'expand_as',
 'expm1',
 'expm1_',
 'exponential_',
 'fill_',
 'fill_diagonal_',
 'fix',
 'fix_',
 'flatten',
 'flip',
 'fliplr',
 'flipud',
 'float',
 'float_power',
 'float_power_',
 'floor',
 'floor_',
 'floor_divide',
 'floor_divide_',
 'fmax',
 'fmin',
 'fmod',
 'fmod_',
 'frac',
 'frac_',
 'frexp',
 'gather',
 'gcd',
 'gcd_',
 'ge',
 'ge_',
 'geometric_',
 'geqrf',
 'ger',
 'get_device',
 'grad',
 'grad_fn',
 'greater',
 'greater_',
 'greater_equal',
 'greater_equal_',
 'gt',
 'gt_',
 'half',
 'hardshrink',
 'has_names',
 'heaviside',
 'heaviside_',
 'histc',
 'histogram',
 'hsplit',
 'hypot',
 'hypot_',
 'i0',
 'i0_',
 'igamma',
 'igamma_',
 'igammac',
 'igammac_',
 'imag',
 'index_add',
 'index_add_',
 'index_copy',
 'index_copy_',
 'index_fill',
 'index_fill_',
 'index_put',
 'index_put_',
 'index_reduce',
 'index_reduce_',
 'index_select',
 'indices',
 'inner',
 'int',
 'int_repr',
 'inverse',
 'ipu',
 'is_coalesced',
 'is_complex',
 'is_conj',
 'is_contiguous',
 'is_cuda',
 'is_distributed',
 'is_floating_point',
 'is_inference',
 'is_ipu',
 'is_leaf',
 'is_meta',
 'is_mkldnn',
 'is_mps',
 'is_neg',
 'is_nested',
 'is_nonzero',
 'is_ort',
 'is_pinned',
 'is_quantized',
 'is_same_size',
 'is_set_to',
 'is_shared',
 'is_signed',
 'is_sparse',
 'is_sparse_csr',
 'is_vulkan',
 'is_xpu',
 'isclose',
 'isfinite',
 'isinf',
 'isnan',
 'isneginf',
 'isposinf',
 'isreal',
 'istft',
 'item',
 'kron',
 'kthvalue',
 'layout',
 'lcm',
 'lcm_',
 'ldexp',
 'ldexp_',
 'le',
 'le_',
 'lerp',
 'lerp_',
 'less',
 'less_',
 'less_equal',
 'less_equal_',
 'lgamma',
 'lgamma_',
 'log',
 'log10',
 'log10_',
 'log1p',
 'log1p_',
 'log2',
 'log2_',
 'log_',
 'log_normal_',
 'log_softmax',
 'logaddexp',
 'logaddexp2',
 'logcumsumexp',
 'logdet',
 'logical_and',
 'logical_and_',
 'logical_not',
 'logical_not_',
 'logical_or',
 'logical_or_',
 'logical_xor',
 'logical_xor_',
 'logit',
 'logit_',
 'logsumexp',
 'long',
 'lstsq',
 'lt',
 'lt_',
 'lu',
 'lu_solve',
 'mH',
 'mT',
 'map2_',
 'map_',
 'masked_fill',
 'masked_fill_',
 'masked_scatter',
 'masked_scatter_',
 'masked_select',
 'matmul',
 'matrix_exp',
 'matrix_power',
 'max',
 'maximum',
 'mean',
 'median',
 'min',
 'minimum',
 'mm',
 'mode',
 'moveaxis',
 'movedim',
 'msort',
 'mul',
 'mul_',
 'multinomial',
 'multiply',
 'multiply_',
 'mv',
 'mvlgamma',
 'mvlgamma_',
 'name',
 'names',
 'nan_to_num',
 'nan_to_num_',
 'nanmean',
 'nanmedian',
 'nanquantile',
 'nansum',
 'narrow',
 'narrow_copy',
 'ndim',
 'ndimension',
 'ne',
 'ne_',
 'neg',
 'neg_',
 'negative',
 'negative_',
 'nelement',
 'new',
 'new_empty',
 'new_empty_strided',
 'new_full',
 'new_ones',
 'new_tensor',
 'new_zeros',
 'nextafter',
 'nextafter_',
 'nonzero',
 'norm',
 'normal_',
 'not_equal',
 'not_equal_',
 'numel',
 'numpy',
 'orgqr',
 'ormqr',
 'outer',
 'output_nr',
 'permute',
 'pin_memory',
 'pinverse',
 'polygamma',
 'polygamma_',
 'positive',
 'pow',
 'pow_',
 'prelu',
 'prod',
 'put',
 'put_',
 'q_per_channel_axis',
 'q_per_channel_scales',
 'q_per_channel_zero_points',
 'q_scale',
 'q_zero_point',
 'qr',
 'qscheme',
 'quantile',
 'rad2deg',
 'rad2deg_',
 'random_',
 'ravel',
 'real',
 'reciprocal',
 'reciprocal_',
 'record_stream',
 'refine_names',
 'register_hook',
 'reinforce',
 'relu',
 'relu_',
 'remainder',
 'remainder_',
 'rename',
 'rename_',
 'renorm',
 'renorm_',
 'repeat',
 'repeat_interleave',
 'requires_grad',
 'requires_grad_',
 'reshape',
 'reshape_as',
 'resize',
 'resize_',
 'resize_as',
 'resize_as_',
 'resize_as_sparse_',
 'resolve_conj',
 'resolve_neg',
 'retain_grad',
 'retains_grad',
 'roll',
 'rot90',
 'round',
 'round_',
 'row_indices',
 'rsqrt',
 'rsqrt_',
 'scatter',
 'scatter_',
 'scatter_add',
 'scatter_add_',
 'scatter_reduce',
 'scatter_reduce_',
 'select',
 'select_scatter',
 'set_',
 'sgn',
 'sgn_',
 'shape',
 'share_memory_',
 'short',
 'sigmoid',
 'sigmoid_',
 'sign',
 'sign_',
 'signbit',
 'sin',
 'sin_',
 'sinc',
 'sinc_',
 'sinh',
 'sinh_',
 'size',
 'slice_scatter',
 'slogdet',
 'smm',
 'softmax',
 'solve',
 'sort',
 'sparse_dim',
 'sparse_mask',
 'sparse_resize_',
 'sparse_resize_and_clear_',
 'split',
 'split_with_sizes',
 'sqrt',
 'sqrt_',
 'square',
 'square_',
 'squeeze',
 'squeeze_',
 'sspaddmm',
 'std',
 'stft',
 'storage',
 'storage_offset',
 'storage_type',
 'stride',
 'sub',
 'sub_',
 'subtract',
 'subtract_',
 'sum',
 'sum_to_size',
 'svd',
 'swapaxes',
 'swapaxes_',
 'swapdims',
 'swapdims_',
 'symeig',
 't',
 't_',
 'take',
 'take_along_dim',
 'tan',
 'tan_',
 'tanh',
 'tanh_',
 'tensor_split',
 'tile',
 'to',
 'to_dense',
 'to_mkldnn',
 'to_padded_tensor',
 'to_sparse',
 'to_sparse_bsc',
 'to_sparse_bsr',
 'to_sparse_coo',
 'to_sparse_csc',
 'to_sparse_csr',
 'tolist',
 'topk',
 'trace',
 'transpose',
 'transpose_',
 'triangular_solve',
 'tril',
 'tril_',
 'triu',
 'triu_',
 'true_divide',
 'true_divide_',
 'trunc',
 'trunc_',
 'type',
 'type_as',
 'unbind',
 'unflatten',
 'unfold',
 'uniform_',
 'unique',
 'unique_consecutive',
 'unsafe_chunk',
 'unsafe_split',
 'unsafe_split_with_sizes',
 'unsqueeze',
 'unsqueeze_',
 'values',
 'var',
 'vdot',
 'view',
 'view_as',
 'vsplit',
 'where',
 'xlogy',
 'xlogy_',
 'xpu',
 'zero_']

返回很多,我们直接排除掉一些Python中特殊方法(以__开头和结束的)和私有方法(以_开头的,直接看几个比较主要的属性:

.is_leaf:记录是否是叶子节点。通过这个属性来确定这个变量的类型

在官方文档中所说的**“graph leaves”,“leaf variables”,都是指像x,y这样的手动创建的、而非运算得到的变量,这些变量成为创建变量。
像z这样的,是通过计算后得到的结果称为结果变量。**

一个变量是创建变量还是结果变量是通过.is_leaf来获取的。

print("x.is_leaf="+str(x.is_leaf))
print("z.is_leaf="+str(z.is_leaf))
x.is_leaf=True
z.is_leaf=False

x是手动创建的没有通过计算,所以他被认为是一个叶子节点也就是一个创建变量,而z是通过x与y的一系列计算得到的,所以不是叶子结点也就是结果变量。

为什么我们执行z.backward()方法会更新x.grad和y.grad呢?

.grad_fn属性记录的就是这部分的操作,虽然.backward()方法也是CPP实现的,但是可以通过Python来进行简单的探索。

grad_fn:记录并且编码了完整的计算历史

z.grad_fn
<AddBackward0 at 0x1d4ab389100>

grad_fn是一个AddBackward0类型的变量 AddBackward0这个类也是用Cpp来写的,但是我们从名字里就能够大概知道,他是加法(ADD)的反反向传播(Backward),看看里面有些什么东西

dir(z.grad_fn)
['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_register_hook_dict',
 '_saved_alpha',
 'metadata',
 'name',
 'next_functions',
 'register_hook',
 'requires_grad']

next_functions就是grad_fn的精华

z.grad_fn.next_functions
((<PowBackward0 at 0x1d4ab397970>, 0), (<PowBackward0 at 0x1d4ab384880>, 0))

next_functions是一个tuple of tuple of PowBackward0 and int。

为什么是2个tuple ?

因为我们的操作是z= x2+y3 刚才的AddBackward0是相加,而前面的操作是乘方 PowBackward0。tuple第一个元素就是x相关的操作记录

xg = z.grad_fn.next_functions[0]
xg
(<PowBackward0 at 0x1d4ab397970>, 0)
xg = z.grad_fn.next_functions[0][0]
xg
<PowBackward0 at 0x1d4ab397970>
dir(xg)
['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_raw_saved_self',
 '_register_hook_dict',
 '_saved_exponent',
 '_saved_self',
 'metadata',
 'name',
 'next_functions',
 'register_hook',
 'requires_grad']
x_leaf=xg.next_functions[0][0]
x_leaf
<AccumulateGrad at 0x1d4b19c2fd0>
type(x_leaf)
AccumulateGrad

在PyTorch的反向图计算中,AccumulateGrad类型代表的就是叶子节点类型,也就是计算图终止节点。AccumulateGrad类中有一个.variable属性指向叶子节点。

x_leaf.variable
tensor([[0.3243, 0.7367, 0.8992, 0.9164, 0.1771],
        [0.7454, 0.7458, 0.2599, 0.7751, 0.4998],
        [0.7042, 0.3879, 0.4514, 0.0187, 0.2506],
        [0.1811, 0.7595, 0.5527, 0.5917, 0.3260],
        [0.2495, 0.6197, 0.9130, 0.4857, 0.7249]], requires_grad=True)

这个.variable的属性就是我们的生成的变量x

print("x_leaf.variable的id:"+str(id(x_leaf.variable)))
print("x的id:"+str(id(x)))
x_leaf.variable的id:2012916808976
x的id:2012916808976

这样整个规程就很清晰了:

  • 当我们执行z.backward()的时候。这个操作将调用z里面的grad_fn这个属性,执行求导的操作。

  • 这个操作将遍历grad_fn的next_functions,然后分别取出里面的Function(AccumulateGrad),执行求导操作。这部分是一个递归的过程直到最后类型为叶子节点。

  • 计算出结果以后,将结果保存到他们对应的variable 这个变量所引用的对象(x和y)的 grad这个属性里面。

  • 求导结束。所有的叶节点的grad变量都得到了相应的更新

最终当我们执行完c.backward()之后,a和b里面的grad值就得到了更新。

扩展Autograd

如果需要自定义autograd扩展新的功能,就需要扩展Function类。因为Function使用autograd来计算结果和梯度,并对操作历史进行编码。

在Function类中最主要的方法就是forward()和backward()他们分别代表了前向传播和反向传播。

一个自定义的Function需要一下三个方法:

__init__ (optional):如果这个操作需要额外的参数则需要定义这个Function的构造函数,不需要的话可以忽略。  

forward():执行前向传播的计算代码  

backward():反向传播时梯度计算的代码。 参数的个数和forward返回值的个数一样,每个参数代表传回到此操作的梯度
# 引入Function便于扩展
from torch.autograd.function import Function
# 定义一个乘以常数的操作(输入参数是张量)
# 方法必须是静态方法,所以要加上@staticmethod 
class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx 用来保存信息这里类似self,并且ctx的属性可以在backward中调用
        ctx.constant=constant
        return tensor *constant
    
    @staticmethod
    def backward(ctx, grad_output):
        # 返回的参数要与输入的参数一样.
        # 第一个输入为3x3的张量,第二个为一个常数
        # 常数的梯度必须是 None.
        return grad_output, None
a=torch.rand(3,3,requires_grad=True)
b = MulConstant.apply(a, 5)
print("a:"+str(a))
print("b:"+str(b)) # b为a的元素乘以5
a:tensor([[0.1500, 0.1894, 0.6854],
        [0.0230, 0.2263, 0.7670],
        [0.7327, 0.1445, 0.0670]], requires_grad=True)
b:tensor([[0.7502, 0.9468, 3.4271],
        [0.1152, 1.1315, 3.8349],
        [3.6637, 0.7226, 0.3352]], grad_fn=<MulConstantBackward>)

反向传播,返回值不是标量,所以backward方法需要参数

b.backward(torch.ones_like(a))
a.grad
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值