1. The dataset used in the code can be downloaded from:
Baidu Netdisk, extraction code: lala
2. Runtime environment:
Tensorflow-gpu==2.4.0
Python==3.7
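Before running the scripts, it can help to confirm that the local interpreter matches the versions listed above. The following is a minimal sketch that uses only the standard library. The helper names `version_matches` and `check_environment` are hypothetical, and the TensorFlow check only reports whether the package can be found; it does not verify that the installed version is 2.4.0.

```python
import importlib.util
import sys


def version_matches(actual, expected):
    """Return True if the leading components of `actual` equal `expected`."""
    return tuple(actual[:len(expected)]) == tuple(expected)


def check_environment(expected_python=(3, 7)):
    """Report whether the interpreter matches and TensorFlow can be found."""
    python_ok = version_matches(sys.version_info, expected_python)
    # find_spec returns None when the top-level package is not installed.
    tf_installed = importlib.util.find_spec("tensorflow") is not None
    print("Python version:", sys.version.split()[0], "| matches:", python_ok)
    print("TensorFlow installed:", tf_installed)
    return python_ok, tf_installed


if __name__ == "__main__":
    check_environment()
```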
3. Training code:
import tensorflow as tf
import os
import pandas as pd
import matplotlib.pyplot as plt

# Enable XLA devices and let TensorFlow allocate GPU memory on demand.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Load the dataset and plot the raw Education/Income points.
data = pd.read_csv(r'dataset/getter.csv')
plt.scatter(data.Education, data.Income)
plt.show()

x = data.Education
y = data.Income

# A single Dense unit with one input learns the line y = w * x + b.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
model.compile(
    optimizer='adam',
    loss='mse'
)
history = model.fit(x, y, epochs=60000, batch_size=20)

# predict() returns an array of shape (n, 1); flatten it for plotting.
pre_y = model.predict(x)
pre_y = pre_y.flatten()
plt.scatter(x, y)
plt.plot(x, pre_y, color='red')
plt.show()

# Make sure the output directory exists before saving the model.
os.makedirs('model_data', exist_ok=True)
model.save(r'model_data/model.h5')
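Because the network is a single Dense unit trained with MSE loss, it is fitting a straight line y = w * x + b. As a sanity check on the training result, the same line can be computed directly by ordinary least squares. The sketch below is only an illustration: the function name `fit_line` is hypothetical, and the sample values are made up for demonstration and are not taken from getter.csv.

```python
def fit_line(xs, ys):
    """Ordinary least squares for y = w * x + b; returns (w, b)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of squared deviations of x, and co-deviation of x and y.
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    w = sxy / sxx
    b = mean_y - w * mean_x
    return w, b


# Illustrative values only, NOT taken from getter.csv.
education = [10.0, 12.0, 14.0, 16.0, 18.0]
income = [25.0, 35.0, 45.0, 55.0, 65.0]

w, b = fit_line(education, income)
print(f"w={w:.3f}, b={b:.3f}")  # prints: w=5.000, b=-25.000
```

With the real data, the inputs would be `data.Education.tolist()` and `data.Income.tolist()`. If training has converged, the kernel and bias returned by `model.get_weights()` should be close to the values computed this way.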
4. Prediction code:
import tensorflow as tf
import os
import pandas as pd
import matplotlib.pyplot as plt

# Enable XLA devices and let TensorFlow allocate GPU memory on demand.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

data = pd.read_csv(r'dataset/getter.csv')
x = data.Education
y = data.Income

# Load the model saved by the training script and predict on the same inputs.
pre_model = tf.keras.models.load_model(r'model_data/model.h5')
pre_y = pre_model.predict(x)

# Plot the observed points and the fitted line.
plt.scatter(x, y)
plt.plot(x, pre_y, color='red')
plt.show()
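The prediction script only plots the fitted line. To put a number on the fit, the same mean squared error used as the training loss can be computed between the observed and predicted values. The following is a minimal sketch in plain Python. The function name `mean_squared_error` is hypothetical, and the sample values are made up for illustration rather than taken from getter.csv.

```python
def mean_squared_error(y_true, y_pred):
    """Average of squared differences between observed and predicted values."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return sum(squared_errors) / len(squared_errors)


# Illustrative values only, NOT taken from getter.csv.
observed = [25.0, 35.0, 45.0]
predicted = [27.0, 35.0, 43.0]

print("MSE:", mean_squared_error(observed, predicted))
```

With the variables from the script above, the call would look like `mean_squared_error(y.tolist(), pre_y.flatten().tolist())`. The `flatten()` is needed because `predict` returns an array of shape (n, 1).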
5. Prediction results