Gradient Descent (Python Code)
Gradient Descent: The Code
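Now suppose the network has only a single output unit, and let's write the weight update as code. As before, we use the sigmoid as the activation function f(h).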
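For reference, the quantities computed by the code that follows can be written out as below. The symbols follow the code comments (η for the learning rate, δ for the error term, ŷ for the network output); the summation index i is only notation added here for illustration.

```latex
\begin{aligned}
h &= \sum_i w_i x_i, & \hat{y} &= f(h) = \frac{1}{1 + e^{-h}}, \\
f'(h) &= f(h)\,\bigl(1 - f(h)\bigr), & \delta &= (y - \hat{y})\, f'(h), \\
\Delta w_i &= \eta\, \delta\, x_i .
\end{aligned}
```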
import numpy as np

# Define the sigmoid function for activations
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Derivative of the sigmoid function
def sigmoid_prime(x):
    return sigmoid(x) * (1 - sigmoid(x))
# Input data
x = np.array([0.1, 0.3])
# Target
y = 0.2
# Input-to-output weights
weights = np.array([-0.8, 0.5])
# The learning rate, eta in the weight step equation
learnrate = 0.5
# The linear combination performed by the node (h in f(h) and f'(h))
h = x[0]*weights[0] + x[1]*weights[1]
# or h = np.dot(x, weights)
# The neural network output (y-hat)
nn_output = sigmoid(h)
# Output error (y - y-hat)
error = y - nn_output
# Output gradient (f'(h))
output_grad = sigmoid_prime(h)
# Error term (lowercase delta)
error_term = error * output_grad
# Gradient descent step
del_w = [learnrate * error_term * x[0],
         learnrate * error_term * x[1]]
# or del_w = learnrate * error_term * x
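One way to sanity-check the weight step is to compare it against a numerical gradient. The sketch below assumes the error being minimized is the squared error E = ½(y − ŷ)², which the text above does not state explicitly; the helper `squared_error` and the step size `eps` are hypothetical names introduced only for this check.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def sigmoid_prime(x):
    return sigmoid(x) * (1 - sigmoid(x))

def squared_error(weights, x, y):
    # Assumed loss (not stated in the original): E = 1/2 * (y - y_hat)^2
    return 0.5 * (y - sigmoid(np.dot(x, weights))) ** 2

x = np.array([0.1, 0.3])
y = 0.2
weights = np.array([-0.8, 0.5])
learnrate = 0.5

# Analytic weight step, computed the same way as in the code above
h = np.dot(x, weights)
error_term = (y - sigmoid(h)) * sigmoid_prime(h)
del_w = learnrate * error_term * x

# Central finite-difference estimate of dE/dw_i for each weight
eps = 1e-6
numeric_grad = np.zeros_like(weights)
for i in range(len(weights)):
    step = np.zeros_like(weights)
    step[i] = eps
    numeric_grad[i] = (squared_error(weights + step, x, y)
                       - squared_error(weights - step, x, y)) / (2 * eps)

# Under the assumed loss, the weight step equals -learnrate * gradient
print(del_w)
print(-learnrate * numeric_grad)
```

If the two printed vectors agree to several decimal places, the analytic step is moving the weights against the gradient of the assumed error, which is what gradient descent is meant to do.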
import numpy as np

def sigmoid(x):
    """
    Calculate sigmoid
    """
    return 1 / (1 + np.exp(-x))

def sigmoid_prime(x):
    """
    Derivative of the sigmoid function
    """
    return sigmoid(x) * (1 - sigmoid(x))
learnrate = 0.5
x = np.array([1, 2, 3, 4])
y = np.array(0.5)
# Initial weights
w = np.array([0.5, -0.5, 0.3, 0.1])
### Calculate one gradient descent step for each weight
### Note: Some steps have been consolidated, so there are
### fewer variable names than in the above sample code
# TODO: Calculate the node's linear combination of inputs and weights
h = np.dot(w, x)
# TODO: Calculate output of neural network
nn_output = sigmoid(h)
# TODO: Calculate error of neural network
error = y - nn_output
# TODO: Calculate the error term
# Remember, this requires the output gradient, which we haven't
# specifically added a variable for.
error_term = error * sigmoid_prime(h)
# TODO: Calculate change in weights
del_w = learnrate * error_term * x
print('Neural Network output:')
print(nn_output)
print('Amount of Error:')
print(error)
print('Change in Weights:')
print(del_w)
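The script above computes the change in weights but never applies it. A minimal sketch of repeated updates might look like the following, assuming the step is simply added to the weights (`w = w + del_w`) and repeated a fixed number of times. The loop count of 10 and the `errors` list are illustrative choices, not part of the original exercise.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def sigmoid_prime(x):
    return sigmoid(x) * (1 - sigmoid(x))

learnrate = 0.5
x = np.array([1, 2, 3, 4])
y = np.array(0.5)
w = np.array([0.5, -0.5, 0.3, 0.1])

errors = []  # error recorded before each update
for step in range(10):  # 10 steps is an arbitrary choice for illustration
    h = np.dot(w, x)
    nn_output = sigmoid(h)
    error = y - nn_output
    errors.append(float(error))
    error_term = error * sigmoid_prime(h)
    del_w = learnrate * error_term * x
    w = w + del_w  # apply the weight step

print('Error before first step:', errors[0])
print('Error before last step:', errors[-1])
```

With these inputs the magnitude of the error shrinks as the steps accumulate. A single training example is enough to show the mechanics, but it says nothing about how the weights would generalize.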