Keras实现单变量线性回归;使用场景:根据工作小时得出报酬,使用的Anaconda 进行的操作
Anaconda下载地址
计算公式
其中x:代表工作显示数
f(x) :代表工作报仇
a和b:通过 *梯度下降算法* 计算出来的值;
梯度下降算法:是线性回归的核心算法
f(x)=xa+b
import tensorflow as tf
print("Tf V{}".format(tf.__version__))
Tf V2.4.1
#[pandas中文网站](https://www.pypandas.cn)
import pandas as pd
#引入图表库
import matplotlib.pyplot as plt
%matplotlib inline
data=pd.read_csv("./work_hours.csv")
data
word | money | |
---|---|---|
0 | 1 | 50 |
1 | 2 | 60 |
2 | 3 | 55 |
3 | 4 | 90 |
4 | 5 | 80 |
5 | 6 | 100 |
6 | 7 | 110 |
7 | 8 | 100 |
8 | 9 | 100 |
9 | 10 | 120 |
#生成线性图表
plt.scatter(data.word,data.money)
<matplotlib.collections.PathCollection at 0x21fb71726a0>
# 图中的 变量
x=data.word
y=data.money
#初始化 顺序模型
model=tf.keras.Sequential()
#添加层 Dense(维度,)
model.add(tf.keras.layers.Dense(1,input_shape=(1,)))
#显示模型层
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 1) 2
=================================================================
Total params: 2
Trainable params: 2
Non-trainable params: 0
_________________________________________________________________
#编译、配置 optimizer:优化方法 名
#loss :损失值
model.compile(optimizer="adam",
loss="mse"
)
#训练 epochs:训练次数 (应该是训练次数越多 月稳定)
history=model.fit(x,y,epochs=50000)
# x:表里面的 word(工作小时数) 那一行的值
model.predict(x)
array([[ 53.363663],
[ 60.72731 ],
[ 68.09096 ],
[ 75.454605],
[ 82.81825 ],
[ 90.18191 ],
[ 97.545555],
[104.9092 ],
[112.27285 ],
[119.6365 ]], dtype=float32)
#假设工作4小时
model.predict(pd.Series([4]))
array([[75.454605]], dtype=float32)
#假设工作14小时
model.predict(pd.Series([14]))
array([[149.0911]], dtype=float32)
#假设工作8小时
model.predict(pd.Series([8]))
array([[104.9092]], dtype=float32)
#假设工作40小时
model.predict(pd.Series([40]))
array([[340.54596]], dtype=float32)
#假设工作24小时
model.predict(pd.Series([24]))
array([[222.72758]], dtype=float32)
#假设工作1小时
model.predict(pd.Series([1]))
array([[53.363663]], dtype=float32)
#假设工作3小时
model.predict(pd.Series([3]))
array([[68.09096]], dtype=float32)
使用到的文件格式
一定要使用的csv格式 的 文件,创建csv格式文件可以去百度搜索
文件下载地址