最小二乘法懂的都懂,不再解释
分部解释一下代码块:
这个为处理数据的步骤,根据具体情况具体分析,最后获得x_data和y_data就可以了
df = pd.read_excel('data2.xls', sheetname=0,header=None)#从第一个 sheet读取
a=np.array(df)
print(a)
a = a[a[:,0].argsort()]
print(a)
print(np.shape(a))
x_data=a[:,0]
y_data=a[:,1]
接下来为预处理数据步骤(变量名比较特别,不用在意)
使用的是经典的max-mix归一化,这对数据的拟合有极大的帮助~
ke=max(x_data)
da=min(x_data)
for i in range(np.shape(a)[0]): #归一化处理
x_data[i]=(x_data[i]-da)/(ke-da)
接下来为参数的设置,大家一般初始化喜欢设置为0,我设置为均匀分布,这个影响不大~
还有选取梯度下降优化器,设置学习率~
W1 = tf.Variable(tf.random_uniform([1]))
W2 = tf.Variable(tf.random_uniform([1]))
W3 = tf.Variable(tf.random_uniform([1]))
b = tf.Variable(tf.zeros([1]))
y = W1 * x_data+W2 * np.multiply(x_data,x_data)+W3 *np.multiply( np.multiply(x_data,x_data),x_data) + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
其他没什么特别了,附上代码~
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
df = pd.read_excel('data2.xls', sheetname=0,header=None)#从第一个 sheet读取
a=np.array(df)
print(a)
a = a[a[:,0].argsort()]
print(a)
print(np.shape(a))
x_data=a[:,0]
y_data=a[:,1]
ke=max(x_data)
da=min(x_data)
for i in range(np.shape(a)[0]): #归一化处理
x_data[i]=(x_data[i]-da)/(ke-da)
plt.scatter(x_data, y_data)
plt.show()
W1 = tf.Variable(tf.random_uniform([1]))
W2 = tf.Variable(tf.random_uniform([1]))
W3 = tf.Variable(tf.random_uniform([1]))
b = tf.Variable(tf.zeros([1]))
y = W1 * x_data+W2 * np.multiply(x_data,x_data)+W3 *np.multiply( np.multiply(x_data,x_data),x_data) + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 执行50000次训练
for step in range(1000):
sess.run(train)
W1, W2, W3, b = sess.run([W1, W2, W3, b])
print(W1)
print(b)
plt.scatter(x_data, y_data)
plt.plot(x_data, W1 * x_data+W2 * np.multiply(x_data,x_data)+W3 *np.multiply( np.multiply(x_data,x_data),x_data) + b, c='r')
plt.show()