使用的是 PyCharm 中的 Jupyter Notebook，代码如下：
#%%
%matplotlib notebook
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
# Load the Boston housing table: 12 feature columns followed by the MEDV label
# (see the column header of the sample data pasted below).
# NOTE(review): the pasted sample is whitespace-separated; if the real file is
# too, read_csv needs sep=r"\s+" — confirm against the actual boston.csv.
df = pd.read_csv(r"C:\Users\Administrator\Desktop\boston.csv", header=0)
# print(df.describe())
# DataFrame -> ndarray in one step; force float so the in-place scaling below
# can never be truncated by an integer dtype.
df = df.values.astype(np.float64)
print(df)
# Min-max scale every feature column into [0, 1]: (v - min) / (max - min).
# Subtracting the minimum is what maps the column's smallest value to 0;
# dividing by the range alone does not normalize the column.
for i in range(12):
    col_min = df[:, i].min()
    col_range = df[:, i].max() - col_min
    # A constant column would give 0/0 = NaN and poison every gradient step.
    df[:, i] = (df[:, i] - col_min) / col_range if col_range else 0.0
x_data = df[:, :12]  # features, shape (rows, 12)
y_data = df[:, 12]   # label (MEDV), shape (rows,)
#%%
# Sanity check after loading: print each array followed by its shape.
for _arr in (x_data, y_data):
    print(_arr, '\n shape=', _arr.shape)
#%%
# Graph inputs: rows of 12 features and the matching single-column targets.
x = tf.placeholder(tf.float32, [None, 12], name="X")
y = tf.placeholder(tf.float32, [None, 1], name="Y")


def model(x, w, b):
    """Return the linear-regression output ``x @ w + b``."""
    return tf.matmul(x, w) + b


# Parameters and the prediction op are created inside the "Model" name scope,
# so their op names carry that prefix.
with tf.name_scope("Model"):
    w = tf.Variable(tf.random_normal([12, 1], stddev=0.01), name="W")  # weights, (12, 1)
    b = tf.Variable(1.0, name="b")                                     # scalar bias
    pred = model(x, w, b)
#%%
# Training hyperparameters.
learning_rate = 0.02
train_epochs = 50
#%%
# Mean squared error between the labels and the model's predictions.
with tf.name_scope("LossFunction"):
    loss_function = tf.reduce_mean(tf.pow(y - pred, 2))
#%%
# Plain gradient descent; `optimizer` is the op that performs one update step.
_gd = tf.train.GradientDescentOptimizer(learning_rate)
optimizer = _gd.minimize(loss_function)
#%%
sess = tf.Session()
init = tf.global_variables_initializer()
#%%
# Give every tf.Variable its initial value before training.
sess.run(init)
#%%
# Per-sample (stochastic) gradient descent over the whole data set each epoch.
loss_list = []  # one entry per training step, i.e. per sample
for epoch in range(train_epochs):
    loss_sum = 0.0
    for xs, ys in zip(x_data, y_data):
        # Feed one row at a time, shaped to match the [None, 12] / [None, 1] placeholders.
        xs = xs.reshape(1, 12)
        ys = ys.reshape(1, 1)
        _, loss = sess.run([optimizer, loss_function], feed_dict={x: xs, y: ys})
        loss_sum = loss_sum + loss
        loss_list.append(loss)
    # sklearn's shuffle returns shuffled copies and leaves its inputs untouched,
    # so the result must be assigned back; otherwise every epoch sees the same
    # sample order. x and y are shuffled together, so rows stay paired.
    x_data, y_data = shuffle(x_data, y_data)
    b0temp = b.eval(session=sess)
    w0temp = w.eval(session=sess)
    loss_average = loss_sum / len(y_data)
    print("epoch=", epoch + 1, "loss=", loss_average, "b=", b0temp, "w=", w0temp)
#%%
# Loss curve: x is the step index (0, 1, 2, ...), y the recorded loss value.
plt.plot(range(len(loss_list)), loss_list)
#%%
# Pick one random row and compare the model's prediction with its label.
# Draw from the actual row count instead of a hard-coded 506.
n = np.random.randint(len(x_data))
print(n)
x_test = x_data[n]
x_test = x_test.reshape(1, 12)  # a single row, shaped for the [None, 12] placeholder
predict = sess.run(pred, feed_dict={x: x_test})
# sess.run returns a (1, 1) array; extract the scalar explicitly rather than
# relying on deprecated ndim>0 array -> float conversion inside "%f".
print("预测值: %f" % predict.item())
target = y_data[n]
print("标签值: %f" % target)
#%%
训练集（boston.csv）数据如下：
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO LSTAT MEDV
0.00632 18 2.31 0 0.538 6.575 65.2 4.09 1 296 15.3 4.98 24
0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 9.14 21.6
0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 4.03 34.7
0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 2.94 33.4
0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 5.33 36.2
0.02985 0 2.18 0 0.458 6.43 58.7 6.0622 3 222 18.7 5.21 28.7
0.08829 12.5 7.87 0 0.524 6.012 66.6 5.5605 5 311 15.2 12.43 22.9
0.14455 12.5 7.87 0 0.524 6.172 96.1 5.9505 5 311 15.2 19.15 27.1
0.21124 12.5 7.87 0 0.524 5.631 100 6.0821 5 311 15.2 29.93 16.5
0.17004 12.5 7.87 0 0.524 6.004 85.9 6.5921 5 311 15.2 17.1 18.9
0.22489 12.5 7.87 0 0.524 6.377 94.3 6.3467 5 311 15.2 20.45 15
0.11747 12.5 7.87 0 0.524 6.009 82.9 6.2267 5 311 15.2 13.27 18.9
0.09378 12.5 7.87 0 0.524 5.889 39 5.4509 5 311 15.2 15.71 21.7
0.62976 0 8.14 0 0.538 5.949 61.8 4.7075 4 307 21 8.26 20.4
0.63796 0 8.14 0 0.538 6.096 84.5 4.4619 4 307 21 10.26 18.2
0.62739 0 8.14 0 0.538 5.834 56.5 4.4986 4 307 21 8.47 19.9
1.05393 0 8.14 0 0.538 5.935 29.3 4.4986 4 307 21 6.58 23.1
0.7842 0 8.14 0 0.538 5.99 81.7 4.2579 4 307 21 14.67 17.5
0.80271 0 8.14 0 0.538 5.456 36.6 3.7965 4 307 21 11.69 20.2
0.7258 0 8.14 0 0.538 5.727 69.5 3.7965 4 307 21 11.28 18.2
1.25179 0 8.14 0 0.538 5.57 98.1 3.7979 4 307 21 21.02 13.6
0.85204 0 8.14 0 0.538 5.965 89.2 4.0123 4 307 21 13.83 19.6
1.23247 0 8.14 0 0.538 6.142 91.7 3.9769 4 307 21 18.72 15.2
0.98843 0 8.14 0 0.538 5.813 100 4.0952 4 307 21 19.88 14.5
0.75026 0 8.14 0 0.538 5.924 94.1 4.3996 4 307 21 16.3 15.6
0.84054 0 8.14 0 0.538 5.599 85.7 4.4546 4 307 21 16.51 13.9
0.67191 0 8.14 0 0.538 5.813 90.3 4.682 4 307 21 14.81 16.6
0.95577 0 8.14 0 0.538 6.047 88.8 4.4534 4 307 21 17.28 14.8
0.77299 0 8.14 0 0.538 6.495 94.4 4.4547 4 307 21 12.8 18.4
1.00245 0 8.14 0 0.538 6.674 87.3 4.239 4 307 21 11.98 21
1.13081 0 8.14 0 0.538 5.713 94.1 4.233 4 307 21 22.6 12.7
1.35472 0 8.14 0 0.538 6.072 100 4.175 4 307 21 13.04 14.5
1.38799 0 8.14 0 0.538 5.95 82 3.99 4 307 21 27.71 13.2
1.15172 0 8.14 0 0.538 5.701 95 3.7872 4 307 21 18.35 13.1
1.61282 0 8.14 0 0.538 6.096 96.9 3.7598 4 307 21 20.34 13.5
0.06417 0 5.96 0 0.499 5.933 68.2 3.3603 5 279 19.2 9.68 18.9
0.09744 0 5.96 0 0.499 5.841 61.4 3.3779 5 279 19.2 11.41 20
0.08014 0 5.96 0 0.499 5.85 41.5 3.9342 5 279 19.2 8.77 21
0.17505 0 5.96 0 0.499 5.966 30.2 3.8473 5 279 19.2 10.13 24.7
0.02763 75 2.95 0 0.428 6.595 21.8 5.4011 3 252 18.3 4.32 30.8
0.03359 75 2.95 0 0.428 7.024 15.8 5.4011 3 252 18.3 1.98 34.9
0.12744 0 6.91 0 0.448 6.77 2.9 5.7209 3 233 17.9 4.84 26.6
0.1415 0 6.91 0 0.448 6.169 6.6 5.7209 3 233 17.9 5.81 25.3
0.15936 0 6.91 0 0.448 6.211 6.5 5.7209 3 233 17.9 7.44 24.7
0.12269 0 6.91 0 0.448 6.069 40 5.7209 3 233 17.9 9.55 21.2
0.17142 0 6.91 0 0.448 5.682 33.8 5.1004 3 233 17.9 10.21 19.3
0.18836 0 6.91 0 0.448 5.786 33.3 5.1004 3 233 17.9 14.15 20
0.22927 0 6.91 0 0.448 6.03 85.5 5.6894 3 233 17.9 18.8 16.6
0.25387 0 6.91 0 0.448 5.399 95.3 5.87 3 233 17.9 30.81 14.4
0.21977 0 6.91 0 0.448 5.602 62 6.0877 3 233 17.9 16.2 19.4
0.08873 21 5.64 0 0.439 5.963 45.7 6.8147 4 243 16.8 13.45 19.7
0.04337 21 5.64 0 0.439 6.115 63 6.8147 4 243 16.8 9.43 20.5
0.0536 21 5.64 0 0.439 6.511 21.1 6.8147 4 243 16.8 5.28 25
0.04981 21 5.64 0 0.439 5.998 21.4 6.8147 4 243 16.8 8.43 23.4
0.0136 75 4 0 0.41 5.888 47.6 7.3197 3 469 21.1 14.8 18.9
0.01311 90 1.22 0 0.403 7.249 21.9 8.6966 5 226 17.9 4.81 35.4
0.02055 85 0.74 0 0.41 6.383 35.7 9.1876 2 313 17.3 5.77 24.7
0.01432 100 1.32 0 0.411 6.816 40.5 8.3248 5 256 15.1 3.95 31.6
0.15445 25 5.13 0 0.453 6.145 29.2 7.8148 8 284 19.7 6.86 23.3
0.10328 25 5.13 0 0.453 5.927 47.2 6.932 8 284 19.7 9.22 19.6
0.14932 25 5.13 0 0.453 5.741 66.2 7.2254 8 284 19.7 13.15 18.7
0.17171 25 5.13 0 0.453 5.966 93.4 6.8185 8 284 19.7 14.44 16
0.11027 25 5.13 0 0.453 6.456 67.8 7.2255 8 284 19.7 6.73 22.2
0.1265 25 5.13 0 0.453 6.762 43.4 7.9809 8 284 19.7 9.5 25
0.01951 17.5 1.38 0 0.4161 7.104 59.5 9.2229 3 216 18.6 8.05 33
0.03584 80 3.37 0 0.398 6.29 17.8 6.6115 4 337 16.1 4.67 23.5
0.04379 80 3.37 0 0.398 5.787 31.1 6.6115 4 337 16.1 10.24 19.4
0.05789 12.5 6.07 0 0.409 5.878 21.4 6.498 4 345 18.9 8.1 22
0.13554 12.5 6.07 0 0.409 5.594 36.8 6.498 4 345 18.9 13.09 17.4
0.12816 12.5 6.07 0 0.409 5