Gradient checking in PyTorch: the gradcheck() function

gradcheck():

gradcheck(func, inputs, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, check_sparse_nnz=False)
Check gradients computed via small finite differences against analytical
gradients w.r.t. tensors in :attr:`inputs` that are of floating point type
and with ``requires_grad=True``.

The check between numerical and analytical gradients uses :func:`~torch.allclose`.

.. note::
    The default values are designed for :attr:`input` of double precision.
    This check will likely fail if :attr:`input` is of less precision, e.g.,
    ``FloatTensor``.

.. warning::
   If any checked tensor in :attr:`input` has overlapping memory, i.e.,
   different indices pointing to the same memory address (e.g., from
   :func:`torch.expand`), this check will likely fail because the numerical
   gradients computed by point perturbation at such indices will change
   values at all other indices that share the same memory address.

Args:
    func (function): a Python function that takes Tensor inputs and returns
        a Tensor or a tuple of Tensors
    inputs (tuple of Tensor or Tensor): inputs to the function
    eps (float, optional): perturbation for finite differences
    atol (float, optional): absolute tolerance
    rtol (float, optional): relative tolerance
    raise_exception (bool, optional): indicating whether to raise an exception if
        the check fails. The exception gives more information about the
        exact nature of the failure. This is helpful when debugging gradchecks.
    check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
        and for any SparseTensor at input, gradcheck will perform check at nnz positions only.

Returns:
    True if all differences satisfy allclose condition
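The pass/fail criterion above can be reproduced directly with :func:`torch.allclose` and gradcheck's default tolerances. A minimal sketch, using the diagonal values from the example output further below (the specific numbers are just for illustration):

```python
import torch

# One numerical vs. analytical gradient entry, as doubles
numerical = torch.tensor([0.0596], dtype=torch.double)
analytical = torch.tensor([0.0483], dtype=torch.double)

# gradcheck passes only if |numerical - analytical| <= atol + rtol * |analytical|
atol, rtol = 1e-5, 1e-3  # gradcheck's defaults
print(torch.allclose(numerical, analytical, rtol=rtol, atol=atol))
# here |0.0596 - 0.0483| = 0.0113 > 1e-5 + 1e-3 * 0.0483, so this entry fails
```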

Example:

	import torch
	import torch.nn as nn
	from torch.autograd import gradcheck

	# Variable is deprecated; a plain tensor with requires_grad=True suffices.
	inputs = torch.randn(1, 1, 2, 2, requires_grad=True)
	conv = nn.Conv2d(1, 1, 1, 1)
	test = gradcheck(lambda x: conv(x), (inputs,))
	print(test)

Output (the input here is single precision, so, as the note above warns, the numerical and analytical Jacobians disagree and gradcheck fails, reporting both):

	numerical: tensor([[0.0596, 0.0000, 0.0000, 0.0000],
	                   [0.0000, 0.0447, 0.0000, 0.0000],
	                   [0.0000, 0.0000, 0.0596, 0.0000],
	                   [0.0000, 0.0000, 0.0000, 0.0447]])
	analytical: tensor([[0.0483, 0.0000, 0.0000, 0.0000],
	                    [0.0000, 0.0483, 0.0000, 0.0000],
	                    [0.0000, 0.0000, 0.0483, 0.0000],
	                    [0.0000, 0.0000, 0.0000, 0.0483]])
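Following the docstring's note, casting both the input and the module to double precision makes the same check pass. A sketch of the fixed example:

```python
import torch
import torch.nn as nn
from torch.autograd import gradcheck

# Double-precision input and weights: finite differences are now accurate
# enough for gradcheck's default tolerances.
inputs = torch.randn(1, 1, 2, 2, dtype=torch.double, requires_grad=True)
conv = nn.Conv2d(1, 1, 1, 1).double()
test = gradcheck(lambda x: conv(x), (inputs,))
print(test)  # True
```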

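Beyond built-in modules, gradcheck is most useful for verifying a hand-written backward pass of a custom :class:`torch.autograd.Function`. A minimal sketch (``MySquare`` is a name chosen here for illustration):

```python
import torch
from torch.autograd import Function, gradcheck

class MySquare(Function):
    """Computes x**2 with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the incoming gradient
        return grad_output * 2 * x

# Double-precision input, as the docstring's note recommends
x = torch.randn(4, dtype=torch.double, requires_grad=True)
print(gradcheck(MySquare.apply, (x,)))  # True
```

If the hand-written backward were wrong (say, returning ``grad_output * x``), gradcheck would report the mismatch between the numerical and analytical Jacobians, just as in the failing example above.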