梯度检验---实例代码

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jiede1/article/details/76973494
在完成神经网络或softmax回归时,需要进行梯度检验。实际上,所有利用到求代价函数的偏导数的算法都需要利用到梯度检验。通过梯度检验,可以判断求得的偏导数是否正确。

梯度检验最核心的目的是,检验下面的式子是否成立:
在这里插入图片描述

其中,J是代价函数,g是代价函数的求导值。

至于更多梯度检验的说明,读者可以参考这篇文章

代码如下:

def simple_quadratic_function(x):
    value = x[0] ** 2 + 3 * x[0] * x[1]

    grad = np.zeros(shape=2, dtype=np.float32)
    grad[0] = 2 * x[0] + 3 * x[1]
    grad[1] = 3 * x[0]

    return value, grad


# theta: a vector of parameters
# J: a function that outputs a real-number. Calling y = J(theta) will return the
# function value at theta.

def compute_gradient(J, theta):
    epsilon = 0.0001

    gradient = np.zeros(theta.shape)
    for i in range(theta.shape[0]):
        theta_epsilon_plus = np.array(theta, dtype=np.float64)
        theta_epsilon_plus[i] = theta[i] + epsilon
        theta_epsilon_minus = np.array(theta, dtype=np.float64)
        theta_epsilon_minus[i] = theta[i] - epsilon

        gradient[i] = (J(theta_epsilon_plus)[0] - J(theta_epsilon_minus)[0]) / (2 * epsilon)
        if i % 100 == 0:
            print ("Computing gradient for input:",i)

    return gradient


# This code can be used to check your numerical gradient implementation
# in computeNumericalGradient.m
# It analytically evaluates the gradient of a very simple function called
# simpleQuadraticFunction (see below) and compares the result with your numerical
# solution. Your numerical gradient implementation is incorrect if
# your numerical solution deviates too much from the analytical solution.

def check_gradient():
    x = np.array([4, 10], dtype=np.float64)
    (value, grad) = simple_quadratic_function(x)

    num_grad = compute_gradient(simple_quadratic_function, x)
    print (num_grad, grad)
    print ("The above two columns you get should be very similar.\n" \
          "(Left-Your Numerical Gradient, Right-Analytical Gradient)\n")

    diff = np.linalg.norm(num_grad - grad) / np.linalg.norm(num_grad + grad)
    print (diff)
    print ("Norm of the difference between numerical and analytical num_grad (should be < 1e-9)\n")

check_gradient()
--------------------- 

作者:Jiede1
来源:CSDN
原文:https://blog.csdn.net/jiede1/article/details/76973494
版权声明:本文为博主原创文章,转载请附上博文链接!

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值