torch.autograd.grad() usage example

Contents

1. Function overview

2. Code example (y = x^2)


1. Function overview

Given an input x and an output y, torch.autograd.grad() returns the derivative (gradient) of y with respect to x: result = \frac{dy}{dx}
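As a quick illustration of the statement above, here is a minimal sketch for a scalar input. The value x = 3 is chosen only for this example; because y is a scalar, grad_outputs can be left out.

```python
import torch

# Scalar input that tracks gradients (the value 3.0 is an arbitrary choice).
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

# grad() returns a tuple with one entry per input tensor.
(dydx,) = torch.autograd.grad(outputs=y, inputs=x)
print(dydx)  # tensor(6.), because dy/dx = 2 * x
```

Unlike y.backward(), this call returns the gradient directly and leaves x.grad untouched.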

def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
         only_inputs=True, allow_unused=False):
    r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``
    containing the pre-computed gradients w.r.t. each of the outputs. If an
    output doesn't require_grad, then the gradient can be ``None``.

    If ``only_inputs`` is ``True``, the function will only return a list of gradients
    w.r.t the specified inputs. If it's ``False``, then gradient w.r.t. all remaining
    leaves will still be computed, and will be accumulated into their ``.grad``
    attribute.

    Arguments:
        outputs (sequence of Tensor): outputs of the differentiated function.
        inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): Gradients w.r.t. each output.
            None values can be specified for scalar Tensors or ones that don't require
            grad. If a None value would be acceptable for all grad_tensors, then this
            argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative products.
            Default: ``False``.
        allow_unused (bool, optional): If ``False``, specifying inputs that were not
            used when computing outputs (and therefore their grad is always zero)
            is an error. Defaults to ``False``.
    """
    if not only_inputs:
        warnings.warn("only_inputs argument is deprecated and is ignored now "
                      "(defaults to True). To accumulate gradient for other "
                      "parts of the graph, please use torch.autograd.backward.")

    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    if grad_outputs is None:
        grad_outputs = [None] * len(outputs)
    elif isinstance(grad_outputs, torch.Tensor):
        grad_outputs = [grad_outputs]
    else:
        grad_outputs = list(grad_outputs)

    grad_outputs = _make_grads(outputs, grad_outputs)
    if retain_graph is None:
        retain_graph = create_graph

    return Variable._execution_engine.run_backward(
        outputs, grad_outputs, retain_graph, create_graph,
        inputs, allow_unused)
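The allow_unused parameter described in the docstring can be tried out with a small sketch like the one below. The tensors x and z and the function y = x^3 are made up for this illustration; z is deliberately left out of the computation.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
z = torch.tensor(5.0, requires_grad=True)  # deliberately not used to compute y

# allow_unused=True: the gradient for the unused input comes back as None.
y = x ** 3
grad_x, grad_z = torch.autograd.grad(y, [x, z], allow_unused=True)
print(grad_x)  # tensor(12.), since d(x^3)/dx = 3 * x^2 = 12 at x = 2
print(grad_z)  # None

# Without allow_unused=True the same request raises a RuntimeError.
y = x ** 3  # rebuild the graph; the first call freed it (retain_graph was not set)
try:
    torch.autograd.grad(y, [x, z])
    unused_is_error = False
except RuntimeError:
    unused_is_error = True
print(unused_is_error)  # True
```

The second y = x ** 3 is needed because retain_graph defaults to the value of create_graph, which is False, so the graph is released after the first call.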

2. Code example (y = x^2)

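import torch

# Build x with x[i][j] = i + j first, then enable gradient tracking.
# Current PyTorch raises a RuntimeError on in-place writes to a leaf tensor
# that already requires grad, so the values are filled in beforehand.
x = torch.zeros(3, 4)
for i in range(3):
    for j in range(4):
        x[i][j] = i + j
x.requires_grad_(True)
y = x ** 2
print(x)
print(y)
# y is not a scalar, so grad_outputs with the same shape as y is required
weight = torch.ones(y.size())
print(weight)
# First derivative: (x**2)' = 2*x. create_graph=True records the graph of
# the derivative so that it can be differentiated again.
dydx = torch.autograd.grad(outputs=y,
                           inputs=x,
                           grad_outputs=weight,
                           retain_graph=True,
                           create_graph=True)
print(dydx[0])
# Second derivative: (2*x)' = 2
d2ydx2 = torch.autograd.grad(outputs=dydx[0],
                             inputs=x,
                             grad_outputs=weight,
                             retain_graph=True,
                             create_graph=True)
print(d2ydx2[0])

x:

tensor([[0., 1., 2., 3.],
        [1., 2., 3., 4.],
        [2., 3., 4., 5.]], requires_grad=True)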

y = x squared:

tensor([[ 0.,  1.,  4.,  9.],
        [ 1.,  4.,  9., 16.],
        [ 4.,  9., 16., 25.]], grad_fn=<PowBackward0>)

weight:

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

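dydx is the first derivative \frac{dy}{dx} = 2x. autograd.grad() multiplies it elementwise by weight (the grad_outputs argument); weight is all ones here, so the values are unchanged (the grad_fn name shown depends on the PyTorch version):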

tensor([[ 0.,  2.,  4.,  6.],
        [ 2.,  4.,  6.,  8.],
        [ 4.,  6.,  8., 10.]], grad_fn=<ThMulBackward>)
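With an all-ones weight the multiplication is invisible. The following sketch uses a non-uniform weight to make it visible; the three-element x and the weight values are chosen only for this illustration.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2

# Non-uniform grad_outputs: each entry scales the gradient of the matching y entry.
weight = torch.tensor([1.0, 0.5, 0.0])
(g,) = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight)

print(g)  # tensor([2., 2., 0.]), i.e. weight * 2 * x
```

In other words, grad_outputs is the vector in a vector-Jacobian product, and for an elementwise function such as y = x^2 that product reduces to an elementwise multiplication.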

d2ydx2 is the second derivative \frac{d^{2}y}{dx^{2}} = (2x)' = 2, again multiplied elementwise by weight:

tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]], grad_fn=<ThMulBackward>)
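For y = x^2 the second derivative is a constant, which hides the dependence on x. The sketch below repeats the two-step pattern for y = x^3 at x = 3 (both chosen only for this illustration) and also shows what happens when create_graph is left at its default.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# create_graph=True keeps the first derivative differentiable.
(dydx,) = torch.autograd.grad(y, x, create_graph=True)
(d2ydx2,) = torch.autograd.grad(dydx, x)

print(dydx.item())    # 27.0, since 3 * x^2 = 27 at x = 3
print(d2ydx2.item())  # 18.0, since 6 * x = 18 at x = 3

# Without create_graph the returned gradient is detached from the graph,
# so it cannot be differentiated a second time.
(plain,) = torch.autograd.grad(x ** 3, x)
print(plain.requires_grad)  # False
```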

Pretty simple, isn't it?
