dnn回归预测_Tensorflow Python:完成的训练DNN回归模型,执行多次后预测会发生巨大变化...

本文档展示了如何使用TensorFlow训练深度神经网络(DNN)回归模型,并探讨了在不同参数设置下,多次运行预测结果的变异性。通过训练和验证过程中的均方根误差(RMSE)变化,分析模型的稳定性。
摘要由CSDN通过智能技术生成

deftrain_dnn_regression_model(learning_rate,regularization_strength,steps,batch_size,hidden_units,feature_columns,training_examples,training_targets,validation_examples,validation_targets,):periods=10steps_per_period=steps/periods#Initialize DNN Regressoroptimizer=tf.train.FtrlOptimizer(learning_rate=learning_rate,l1_regularization_strength=regularization_strength)optimizer=tf.contrib.estimator.clip_gradients_by_norm(optimizer,5.0)dnn_regressor=tf.estimator.DNNRegressor(feature_columns=feature_columns,hidden_units=hidden_units,optimizer=optimizer,activation_fn=tf.nn.leaky_relu)#Training Functionstraining_input_fn=lambda:input_fn(training_examples,training_targets,batch_size=batch_size)predict_training_input_fn=lambda:input_fn(training_examples,training_targets,num_epochs=1,shuffle=False)#Validation Functionpredict_validation_input_fn=lambda:input_fn(validation_examples,validation_targets,num_epochs=1,shuffle=False)#Train Modeltraining_rmse=[]validation_rmse=[]print("Training Model")forperiodinrange(0,periods):linear_regressor.train(input_fn=training_input_fn,#Manually break total steps by 10steps=steps_per_period)#Use Sklearn to calculate RMSEtraining_predictions=linear_regressor.predict(input_fn=predict_training_input_fn)training_predictions=np.array([item['predictions'][0]foritemintraining_predictions])training_root_mean_squared_error=math.sqrt(metrics.mean_squared_error(training_predictions,training_targets))#Calculate Validation RMSEvalidation_predictions=linear_regressor.predict(input_fn=predict_validation_input_fn)validation_predictions=np.array([item['predictions'][0]foriteminvalidation_predictions])validation_root_mean_squared_error=math.sqrt(metrics.mean_squared_error(validation_predictions,validation_targets))#Append Lossestraining_rmse.append(training_root_mean_squared_error)validation_rmse.append(validation_root_mean_squared_error)print("Period:",period,"RMSE:",training_root_mean_squared_error)print("Training Finished")#Graphplt.ylabel("RMSE")plt.xlabel("Periods")plt.title("Root Mean Squared Error vs. Periods")plt.tight_layout()plt.plot(training_rmse,label="training")plt.plot(validation_rmse,label="validation")plt.legend()returndnn_regressor

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值