使用 TensorFlow(v1 兼容模式)直接实现房价预测问题。本来想输出 R²,但不清楚怎么用 TensorFlow 直接计算 R²,所以改用 loss 作为训练过程的评估指标。
import tensorflow as tf
# Rebind `tf` to the v1 compatibility module and disable eager/v2
# behavior, so the rest of the script can use TF1-style placeholders,
# tf.layers, and Sessions.
tf = tf.compat.v1
tf.disable_v2_behavior()
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
import numpy as np
# Network dimensions: 8 input features (California housing has 8
# numeric features), two ReLU hidden layers, one regression output.
n_inputs=8
n_output=1
n_hidden1=100
n_hidden2=50
def process_feature(X):
    """Scale features: standardize each column, then map it into [-1, 1].

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)

    Returns
    -------
    ndarray of shape (n_samples, n_features) with each column in [-1, 1].

    NOTE(review): the scalers are fit on whatever array is passed in, so
    calling this separately on train and test splits fits *different*
    scalers — the two splits are not transformed consistently. Consider
    fitting on the training split only and reusing the fitted scalers.
    """
    # Standardization is an affine per-column transform, so the
    # MinMaxScaler below produces (numerically almost) the same result
    # with or without it; it is kept to preserve the original behavior.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    scaler = MinMaxScaler(feature_range=(-1, 1))
    X = scaler.fit_transform(X)
    return X
# --- Graph definition: 2-hidden-layer MLP regressor trained with MSE ---
X = tf.placeholder(tf.float32, shape=(None, n_inputs))
# Give y an explicit (None, 1) shape: with shape=None, feeding a flat
# (N,) target would make y - output broadcast into an (N, N) matrix and
# silently corrupt the loss.
y = tf.placeholder(tf.float32, shape=(None, n_output))
hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu)
hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu)
# Linear output layer (no activation) — raw regression prediction.
output = tf.layers.dense(hidden2, n_output)
# Per-sample squared error; despite the name, this is NOT a cross
# entropy (name kept so external references still resolve).
cross_entry = tf.square(y - output)
loss = tf.reduce_mean(cross_entry)
optimizer = tf.train.GradientSescentOptimizer(learning_rate=0.1) if False else tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss)
# NOTE(review): argmax over axis 1 of an (N, 1) tensor is always 0, so
# this "accuracy" is identically 1.0 — it is meaningless for regression
# and is never evaluated below. Kept only so the module-level names
# `correct` / `accuracy_score` remain available.
correct = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
accuracy_score = tf.reduce_mean(tf.cast(correct, tf.float32))
# Load the California housing dataset (20640 samples, 8 features).
# Targets are reshaped to a column vector so they match the network's
# (None, 1) output when fed into the loss.
lode_data=fetch_california_housing()
X_lode = lode_data.data
y_lode = lode_data.target.reshape(-1, 1)
with tf.Session() as sess:
    """Train the MLP by single-sample SGD and report progress/final R^2."""
    tf.global_variables_initializer().run()
    X_train, X_test, y_train, y_test = train_test_split(
        X_lode, y_lode, test_size=0.5, random_state=0)
    # NOTE(review): process_feature fits its scalers on each split
    # independently, so train and test are not scaled consistently.
    X_train = np.array(process_feature(X_train))
    X_test = np.array(process_feature(X_test))
    for step in range(5000):
        # Plain stochastic gradient descent: one random sample per step.
        i = np.random.randint(0, len(X_train))
        X_i = X_train[i].reshape(1, -1)
        y_i = y_train[i].reshape(1, -1)
        sess.run(train_op, feed_dict={X: X_i, y: y_i})
        if step % 100 == 0:
            # Evaluate the loss only when it is printed (the original
            # ran it every step and discarded 99/100 results), and on
            # the FULL training set — the single-sample loss bounced
            # between ~1e-4 and ~7 and showed no trend (see the log).
            loss_ = sess.run(loss, feed_dict={X: X_train, y: y_train})
            print("第{}个:".format(step))
            print("loss={}".format(loss_))
    # Final held-out evaluation, including the R^2 the author wanted:
    # R^2 = 1 - SS_res / SS_tot, computed with numpy from predictions.
    y_pred = sess.run(output, feed_dict={X: X_test})
    test_mse = float(np.mean((y_pred - y_test) ** 2))
    ss_res = float(np.sum((y_test - y_pred) ** 2))
    ss_tot = float(np.sum((y_test - y_test.mean()) ** 2))
    print("test mse={}".format(test_mse))
    print("test r2={}".format(1.0 - ss_res / ss_tot))
输出结果为
第0个:
loss=1.931797742843628
第100个:
loss=0.0004440907505340874
第200个:
loss=0.1922372281551361
第300个:
loss=0.064860038459301
第400个:
loss=2.317430019378662
第500个:
loss=0.0021678160410374403
第600个:
loss=0.09645801782608032
第700个:
loss=0.008708802983164787
第800个:
loss=4.162382125854492
第900个:
loss=1.075676679611206
第1000个:
loss=0.017900949344038963
第1100个:
loss=0.0009775909129530191
第1200个:
loss=0.7814558744430542
第1300个:
loss=1.3715777397155762
第1400个:
loss=0.8856215476989746
第1500个:
loss=0.22965915501117706
第1600个:
loss=0.11970070749521255
第1700个:
loss=0.011234978213906288
第1800个:
loss=0.5727159380912781
第1900个:
loss=0.13193991780281067
第2000个:
loss=3.595919609069824
第2100个:
loss=0.4728265702724457
第2200个:
loss=0.8856393694877625
第2300个:
loss=0.014978619292378426
第2400个:
loss=0.3310540020465851
第2500个:
loss=0.011562751606106758
第2600个:
loss=0.001214497722685337
第2700个:
loss=0.9096536636352539
第2800个:
loss=0.0442768819630146
第2900个:
loss=0.4617774188518524
第3000个:
loss=0.18205884099006653
第3100个:
loss=0.019049761816859245
第3200个:
loss=0.09607062488794327
第3300个:
loss=3.223975419998169
第3400个:
loss=0.0025411953683942556
第3500个:
loss=0.7368151545524597
第3600个:
loss=0.7690043449401855
第3700个:
loss=0.22527238726615906
第3800个:
loss=7.095444202423096
第3900个:
loss=0.9337527751922607
第4000个:
loss=0.17915597558021545
第4100个:
loss=0.8352121710777283
第4200个:
loss=0.05043191462755203
第4300个:
loss=1.4584596157073975
第4400个:
loss=0.5663894414901733
第4500个:
loss=0.8200605511665344
第4600个:
loss=0.22166895866394043
第4700个:
loss=6.3242340087890625
第4800个:
loss=0.8125528693199158
第4900个:
loss=1.1417784690856934