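This article is based on Baidu's PaddlePaddle platform.
Project page: House Price Prediction with PaddlePaddle
Boston House Price Prediction
- Framework: Baidu PaddlePaddle
- Date: 2021-11-14
- Type: Structured data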
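1. Data Description
The Boston housing dataset has 506 rows and 14 columns: 13 features and 1 target (medv). The columns are:
- crim: per-capita crime rate by town
- zn: proportion of residential land zoned for lots over 25,000 sq. ft.
- indus: proportion of non-retail business acres per town
- chas: Charles River dummy variable (1 if the tract bounds the river, 0 otherwise)
- nox: nitric oxides concentration (parts per 10 million)
- rm: average number of rooms per dwelling
- age: proportion of owner-occupied units built before 1940
- dis: weighted mean of distances to five Boston employment centers
- rad: index of accessibility to radial highways
- tax: full-value property-tax rate per $10,000
- ptratio: pupil-teacher ratio by town
- black: 1000(Bk - 0.63)^2, where Bk is the proportion of Black residents by town
- lstat: percentage of lower-status population
- medv (target): median value of owner-occupied homes, in $1000s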
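To make the feature order concrete, here is a small sketch that pairs the 13 feature names with one sample vector. It assumes the features appear in the column order listed above, with medv kept separate as the label; the names `FEATURE_NAMES`, `LABEL_NAME` and `describe_sample` are hypothetical helpers, not part of Paddle.

```python
# Column order as listed above (assumed to match the dataset's column order).
# medv is the 14th column and serves as the label, so it is not a feature.
FEATURE_NAMES = [
    "crim", "zn", "indus", "chas", "nox", "rm", "age",
    "dis", "rad", "tax", "ptratio", "black", "lstat",
]
LABEL_NAME = "medv"


def describe_sample(features, label):
    """Hypothetical helper: map a 13-value feature vector and its label to named fields."""
    if len(features) != len(FEATURE_NAMES):
        raise ValueError(f"expected {len(FEATURE_NAMES)} features, got {len(features)}")
    named = dict(zip(FEATURE_NAMES, (float(v) for v in features)))
    named[LABEL_NAME] = float(label)
    return named


sample = describe_sample([0.0] * 13, 24.0)
print(len(FEATURE_NAMES), sample["medv"])
```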
2. Loading and Preprocessing the Data
# Load the Boston housing data
import os
import paddle
import numpy as np

# Training batch size
BATCH_SIZE = 20

# Training set
train_datasets = paddle.text.datasets.UCIHousing(mode='train')
# Validation set (the dataset's 'test' split)
valid_datasets = paddle.text.datasets.UCIHousing(mode='test')

# Training loader: reshuffles every epoch and reads BATCH_SIZE samples at a time;
# the final incomplete batch is dropped
train_loader = paddle.io.DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Validation loader: also shuffled, and also drops the final incomplete batch
valid_loader = paddle.io.DataLoader(valid_datasets, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Print the dataset type
print(type(train_datasets))
<class 'paddle.text.datasets.uci_housing.UCIHousing'>
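With `drop_last=True`, each loader yields `len(dataset) // BATCH_SIZE` full batches per epoch. A quick sketch, assuming an 80/20 split of the 506 rows (404 training and 102 test samples; the 102 test rows agree with the prediction output in section 5). `TRAIN_RATIO` and `batches_per_epoch` are illustrative names, not Paddle APIs.

```python
BATCH_SIZE = 20
NUM_ROWS = 506
TRAIN_RATIO = 0.8  # assumption: the split ratio used by UCIHousing

num_train = int(NUM_ROWS * TRAIN_RATIO)  # 404
num_test = NUM_ROWS - num_train          # 102


def batches_per_epoch(n, batch_size, drop_last):
    """Number of batches a loader yields for n samples."""
    full, rest = divmod(n, batch_size)
    return full if drop_last or rest == 0 else full + 1


train_batches = batches_per_epoch(num_train, BATCH_SIZE, drop_last=True)
test_batches = batches_per_epoch(num_test, BATCH_SIZE, drop_last=True)
print(num_train, num_test, train_batches, test_batches)  # 404 102 20 5
```

One consequence: with only 20 training batches per epoch, `batch_id` never exceeds 19, so the `batch_id % 40 == 0` check in the training loop fires only at batch 0 of each epoch.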
# Inspect the first training sample
print(train_datasets[0])
# Each sample is a (features, label) tuple: 13 normalized features plus the label medv
(array([-0.0405441 , 0.06636363, -0.32356226, -0.06916996, -0.03435197,
0.05563625, -0.03475696, 0.02682186, -0.37171334, -0.21419305,
-0.33569506, 0.10143217, -0.21172912], dtype=float32), array([24.], dtype=float32))
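The feature values printed above are clearly not raw values (a crime rate of -0.04 is impossible), while the label 24.0 is a plausible medv. This indicates that `UCIHousing` normalizes the features when loading them. As far as I can tell from Paddle's source, each feature column is mean-centered and divided by its range (max minus min). The NumPy sketch below reproduces that idea on a toy array; `normalize_features` is a hypothetical name, not a Paddle API, and the zero-range guard is my own addition.

```python
import numpy as np


def normalize_features(data):
    """Center every feature column on its mean and scale it by its range.

    `data` has shape (num_rows, num_cols); the last column is the label
    and is left untouched.
    """
    data = np.asarray(data, dtype=np.float64).copy()
    features = data[:, :-1]
    mins, maxs = features.min(axis=0), features.max(axis=0)
    means = features.mean(axis=0)
    # Guard against constant columns (added here; not assumed to be in Paddle)
    ranges = np.where(maxs > mins, maxs - mins, 1.0)
    data[:, :-1] = (features - means) / ranges
    return data


toy = np.array([[1.0, 10.0, 5.0],
                [3.0, 20.0, 7.0],
                [5.0, 30.0, 9.0]])
print(normalize_features(toy))
```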
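3. Building the Network
# Network: a single fully connected layer mapping 13 features to 1 output
net = paddle.nn.Linear(13, 1)
# Optimizer: stochastic gradient descent
optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=net.parameters())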
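`paddle.nn.Linear(13, 1)` holds a weight of shape [13, 1] and a bias of shape [1] and computes `out = x @ W + b`. Combined with the mean squared error used in the training loop and plain SGD, one training step amounts to the NumPy sketch below. It is an illustrative re-derivation of the math on synthetic data, not Paddle's internal code; `sgd_step` is a hypothetical name.

```python
import numpy as np


def sgd_step(X, y, W, b, lr):
    """One plain-SGD step on the mean squared error of a linear model.

    X: (N, 13) features, y: (N, 1) labels, W: (13, 1) weights, b: (1,) bias.
    Returns the updated (W, b) and the loss before the update.
    """
    n = X.shape[0]
    pred = X @ W + b                    # forward pass, shape (N, 1)
    err = pred - y
    loss = float(np.mean(err ** 2))     # mean of the squared errors
    grad_W = 2.0 / n * X.T @ err        # d loss / d W, shape (13, 1)
    grad_b = 2.0 / n * err.sum(axis=0)  # d loss / d b, shape (1,)
    return W - lr * grad_W, b - lr * grad_b, loss


# Synthetic data with an exact linear relationship (illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 13))
true_W = rng.normal(size=(13, 1))
y = X @ true_W + 3.0

W, b = np.zeros((13, 1)), np.zeros(1)
losses = []
for _ in range(200):
    W, b, loss = sgd_step(X, y, W, b, lr=0.05)
    losses.append(loss)
print(losses[0], losses[-1])
```

In Paddle, `train_loss.backward()` computes `grad_W` and `grad_b`, and `optimizer.step()` applies the `W - lr * grad_W` update.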
4. Training and Evaluating the Model
4.1 Defining the plotting function
# Plotting helper for the training curve
import matplotlib.pyplot as plt

iter = 0          # running counter used as the x-axis (note: shadows Python's built-in iter)
iters = []
train_costs = []

def draw_train_process(iters, train_costs):
    title = 'training costs'
    plt.title(title, fontsize=24)
    plt.xlabel('iter', fontsize=14)
    plt.ylabel('cost', fontsize=14)
    plt.plot(iters, train_costs, color='red', label='training cost')
    plt.legend()
    plt.grid()
    plt.show()
4.2 Training and saving the model
# Number of training epochs
EPOCH_NUM = 50
for pass_id in range(EPOCH_NUM):
    # Training phase; the loss of the first batch of each epoch is printed and recorded
    net.train()
    for batch_id, data in enumerate(train_loader()):
        # Split the batch into features and labels
        inputs = paddle.to_tensor(data[0])
        labels = paddle.to_tensor(data[1])
        # Forward pass
        out = net(inputs)
        # Loss: mean squared error
        train_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
        # Backpropagation
        train_loss.backward()
        # Update the parameters, then clear the gradients
        optimizer.step()
        optimizer.clear_grad()
        # Log every 40 batches. An epoch has only 20 batches (404 // 20),
        # so in practice this fires only at batch 0.
        if batch_id % 40 == 0:
            print("Pass id: %d, cost: %0.5f" % (pass_id, float(train_loss)))
            iter = iter + BATCH_SIZE
            iters.append(iter)
            train_costs.append(float(train_loss))

    # Validation phase; only the loss of the last validation batch is printed
    net.eval()
    with paddle.no_grad():
        for batch_id, data in enumerate(valid_loader()):
            inputs = paddle.to_tensor(data[0])
            labels = paddle.to_tensor(data[1])
            out = net(inputs)
            test_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
    print("Pass id: %d, cost: %0.5f" % (pass_id, float(test_loss)))

# Save the model parameters
paddle.save(net.state_dict(), 'fit_a_line.pdparams')
draw_train_process(iters, train_costs)
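Output (for each epoch, the first line is the training loss of batch 0 and the second is the validation loss of the last validation batch):
Pass id: 0, cost: 697.82361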
Pass id: 0, cost: 244.72691
Pass id: 1, cost: 552.90552
Pass id: 1, cost: 221.64818
Pass id: 2, cost: 429.48627
Pass id: 2, cost: 192.41925
Pass id: 3, cost: 484.38715
Pass id: 3, cost: 173.44687
Pass id: 4, cost: 387.63263
Pass id: 4, cost: 221.82217
Pass id: 5, cost: 402.83636
Pass id: 5, cost: 128.57277
Pass id: 6, cost: 547.67670
Pass id: 6, cost: 173.92017
Pass id: 7, cost: 308.60843
Pass id: 7, cost: 147.69878
Pass id: 8, cost: 219.78215
Pass id: 8, cost: 107.03541
Pass id: 9, cost: 208.08089
Pass id: 9, cost: 112.56085
Pass id: 10, cost: 311.31699
Pass id: 10, cost: 123.64721
Pass id: 11, cost: 363.99536
Pass id: 11, cost: 105.76099
Pass id: 12, cost: 420.03595
Pass id: 12, cost: 72.11946
Pass id: 13, cost: 327.58844
Pass id: 13, cost: 88.04731
Pass id: 14, cost: 390.53604
Pass id: 14, cost: 85.30426
Pass id: 15, cost: 230.41187
Pass id: 15, cost: 63.41622
Pass id: 16, cost: 166.12924
Pass id: 16, cost: 63.62786
Pass id: 17, cost: 236.01692
Pass id: 17, cost: 77.35274
Pass id: 18, cost: 290.79498
Pass id: 18, cost: 33.24523
Pass id: 19, cost: 97.46697
Pass id: 19, cost: 51.36712
Pass id: 20, cost: 340.30243
Pass id: 20, cost: 32.06834
Pass id: 21, cost: 168.25182
Pass id: 21, cost: 38.49080
Pass id: 22, cost: 175.69730
Pass id: 22, cost: 35.95297
Pass id: 23, cost: 205.52931
Pass id: 23, cost: 71.39227
Pass id: 24, cost: 207.87970
Pass id: 24, cost: 28.49346
Pass id: 25, cost: 116.07450
Pass id: 25, cost: 35.36184
Pass id: 26, cost: 182.95099
Pass id: 26, cost: 33.28909
Pass id: 27, cost: 181.40346
Pass id: 27, cost: 40.69643
Pass id: 28, cost: 53.31166
Pass id: 28, cost: 23.08777
Pass id: 29, cost: 159.42133
Pass id: 29, cost: 49.56973
Pass id: 30, cost: 148.92783
Pass id: 30, cost: 34.58585
Pass id: 31, cost: 98.45925
Pass id: 31, cost: 22.08068
Pass id: 32, cost: 64.72221
Pass id: 32, cost: 18.82707
Pass id: 33, cost: 39.10216
Pass id: 33, cost: 41.70573
Pass id: 34, cost: 86.37087
Pass id: 34, cost: 27.22211
Pass id: 35, cost: 30.80962
Pass id: 35, cost: 32.63113
Pass id: 36, cost: 112.65536
Pass id: 36, cost: 16.17883
Pass id: 37, cost: 69.60898
Pass id: 37, cost: 19.16402
Pass id: 38, cost: 39.25970
Pass id: 38, cost: 22.66752
Pass id: 39, cost: 19.43631
Pass id: 39, cost: 24.29722
Pass id: 40, cost: 118.90907
Pass id: 40, cost: 17.93452
Pass id: 41, cost: 63.60981
Pass id: 41, cost: 21.98477
Pass id: 42, cost: 48.37024
Pass id: 42, cost: 42.80917
Pass id: 43, cost: 120.80444
Pass id: 43, cost: 16.80975
Pass id: 44, cost: 67.58469
Pass id: 44, cost: 19.20858
Pass id: 45, cost: 19.36415
Pass id: 45, cost: 31.67330
Pass id: 46, cost: 67.76525
Pass id: 46, cost: 19.26246
Pass id: 47, cost: 91.39543
Pass id: 47, cost: 30.07190
Pass id: 48, cost: 120.14481
Pass id: 48, cost: 17.26798
Pass id: 49, cost: 89.36503
Pass id: 49, cost: 21.93978
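The validation numbers above come from a single batch, which is why they jump around from epoch to epoch. A loss over the whole validation set weights each batch's mean loss by its batch size. A minimal sketch, assuming you collect per-batch mean losses and batch sizes in two lists; `dataset_mse` is a hypothetical helper.

```python
def dataset_mse(batch_losses, batch_sizes):
    """Combine per-batch mean losses into the mean loss over all samples."""
    if len(batch_losses) != len(batch_sizes) or not batch_sizes:
        raise ValueError("need one size per loss and at least one batch")
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# e.g. five full batches of 20 plus a final batch of 2 (as with drop_last=False)
print(dataset_mse([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], [20, 20, 20, 20, 20, 2]))
```

Inside the Paddle validation loop, you would append `float(test_loss)` and `inputs.shape[0]` for every batch and call this once per epoch.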
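5. Model Prediction
5.1 Defining a function to plot ground truth against predictions
infer_results = []
groud_truths = []

# Scatter plot of ground truth vs. predicted values
def draw_infer_result(ground_truths, infer_results):
    title = 'Boston'
    plt.title(title, fontsize=24)
    # Reference line y = x: points on it are perfect predictions.
    # medv tops out at 50 (in $1000s), so draw the line over 0..50.
    x = np.arange(0, 51)
    y = x
    plt.plot(x, y, label='y = x')
    plt.xlabel('ground truth', fontsize=14)
    plt.ylabel('infer result', fontsize=14)
    plt.scatter(ground_truths, infer_results, color='green', label='infer result')
    plt.legend()
    plt.grid()
    plt.show()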
5.2 Running predictions with the saved model
import paddle
import numpy as np
import matplotlib.pyplot as plt

# Test split, loaded as a single batch (it has fewer than 200 samples)
valid_datasets = paddle.text.datasets.UCIHousing(mode='test')
infer_loader = paddle.io.DataLoader(valid_datasets, batch_size=200)

# Rebuild the same architecture first
infer_net = paddle.nn.Linear(13, 1)
# Then load the saved parameters
param = paddle.load('fit_a_line.pdparams')
infer_net.set_state_dict(param)
infer_net.eval()

data = next(infer_loader())
inputs = paddle.to_tensor(data[0])
with paddle.no_grad():
    results = infer_net(inputs)

for idx, item in enumerate(zip(results, data[1])):
    print("Index:%d, Infer Result: %.2f, Ground Truth: %.2f" % (idx, float(item[0]), float(item[1])))
    infer_results.append(item[0].numpy()[0])
    groud_truths.append(item[1].numpy()[0])

draw_infer_result(groud_truths, infer_results)
Index:0, Infer Result: 12.86, Ground Truth: 8.50
Index:1, Infer Result: 12.75, Ground Truth: 5.00
Index:2, Infer Result: 12.67, Ground Truth: 11.90
Index:3, Infer Result: 14.19, Ground Truth: 27.90
Index:4, Infer Result: 13.26, Ground Truth: 17.20
Index:5, Infer Result: 13.88, Ground Truth: 27.50
Index:6, Infer Result: 13.41, Ground Truth: 15.00
Index:7, Infer Result: 13.40, Ground Truth: 17.20
Index:8, Infer Result: 11.47, Ground Truth: 17.90
Index:9, Infer Result: 13.08, Ground Truth: 16.30
Index:10, Infer Result: 10.83, Ground Truth: 7.00
Index:11, Infer Result: 12.48, Ground Truth: 7.20
Index:12, Infer Result: 13.14, Ground Truth: 7.50
Index:13, Infer Result: 12.54, Ground Truth: 10.40
Index:14, Infer Result: 12.28, Ground Truth: 8.80
Index:15, Infer Result: 13.67, Ground Truth: 8.40
Index:16, Infer Result: 14.19, Ground Truth: 16.70
Index:17, Infer Result: 14.12, Ground Truth: 14.20
Index:18, Infer Result: 14.37, Ground Truth: 20.80
Index:19, Infer Result: 13.34, Ground Truth: 13.40
Index:20, Infer Result: 13.96, Ground Truth: 11.70
Index:21, Infer Result: 12.72, Ground Truth: 8.30
Index:22, Infer Result: 14.42, Ground Truth: 10.20
Index:23, Infer Result: 13.74, Ground Truth: 10.90
Index:24, Infer Result: 13.72, Ground Truth: 11.00
Index:25, Infer Result: 13.13, Ground Truth: 9.50
Index:26, Infer Result: 14.08, Ground Truth: 14.50
Index:27, Infer Result: 13.92, Ground Truth: 14.10
Index:28, Infer Result: 14.86, Ground Truth: 16.10
Index:29, Infer Result: 14.02, Ground Truth: 14.30
Index:30, Infer Result: 13.76, Ground Truth: 11.70
Index:31, Infer Result: 13.28, Ground Truth: 13.40
Index:32, Infer Result: 13.44, Ground Truth: 9.60
Index:33, Infer Result: 12.45, Ground Truth: 8.70
Index:34, Infer Result: 12.15, Ground Truth: 8.40
Index:35, Infer Result: 13.51, Ground Truth: 12.80
Index:36, Infer Result: 13.52, Ground Truth: 10.50
Index:37, Infer Result: 14.01, Ground Truth: 17.10
Index:38, Infer Result: 14.17, Ground Truth: 18.40
Index:39, Infer Result: 14.03, Ground Truth: 15.40
Index:40, Infer Result: 13.11, Ground Truth: 10.80
Index:41, Infer Result: 13.01, Ground Truth: 11.80
Index:42, Infer Result: 14.04, Ground Truth: 14.90
Index:43, Infer Result: 14.22, Ground Truth: 12.60
Index:44, Infer Result: 14.10, Ground Truth: 14.10
Index:45, Infer Result: 13.93, Ground Truth: 13.00
Index:46, Infer Result: 13.73, Ground Truth: 13.40
Index:47, Infer Result: 14.30, Ground Truth: 15.20
Index:48, Infer Result: 14.39, Ground Truth: 16.10
Index:49, Infer Result: 14.67, Ground Truth: 17.80
Index:50, Infer Result: 13.58, Ground Truth: 14.90
Index:51, Infer Result: 13.85, Ground Truth: 14.10
Index:52, Infer Result: 13.45, Ground Truth: 12.70
Index:53, Infer Result: 13.72, Ground Truth: 13.50
Index:54, Infer Result: 14.41, Ground Truth: 14.90
Index:55, Infer Result: 14.69, Ground Truth: 20.00
Index:56, Infer Result: 14.41, Ground Truth: 16.40
Index:57, Infer Result: 14.75, Ground Truth: 17.70
Index:58, Infer Result: 14.88, Ground Truth: 19.50
Index:59, Infer Result: 15.11, Ground Truth: 20.20
Index:60, Infer Result: 15.39, Ground Truth: 21.40
Index:61, Infer Result: 15.43, Ground Truth: 19.90
Index:62, Infer Result: 13.83, Ground Truth: 19.00
Index:63, Infer Result: 14.07, Ground Truth: 19.10
Index:64, Infer Result: 14.76, Ground Truth: 19.10
Index:65, Infer Result: 15.33, Ground Truth: 20.10
Index:66, Infer Result: 14.93, Ground Truth: 19.90
Index:67, Infer Result: 15.19, Ground Truth: 19.60
Index:68, Infer Result: 15.38, Ground Truth: 23.20
Index:69, Infer Result: 15.82, Ground Truth: 29.80
Index:70, Infer Result: 14.06, Ground Truth: 13.80
Index:71, Infer Result: 13.74, Ground Truth: 13.30
Index:72, Infer Result: 14.54, Ground Truth: 16.70
Index:73, Infer Result: 13.26, Ground Truth: 12.00
Index:74, Infer Result: 14.30, Ground Truth: 14.60
Index:75, Infer Result: 14.83, Ground Truth: 21.40
Index:76, Infer Result: 15.92, Ground Truth: 23.00
Index:77, Infer Result: 16.14, Ground Truth: 23.70
Index:78, Infer Result: 16.29, Ground Truth: 25.00
Index:79, Infer Result: 16.34, Ground Truth: 21.80
Index:80, Infer Result: 15.95, Ground Truth: 20.60
Index:81, Infer Result: 16.18, Ground Truth: 21.20
Index:82, Infer Result: 15.11, Ground Truth: 19.10
Index:83, Infer Result: 15.84, Ground Truth: 20.60
Index:84, Infer Result: 15.65, Ground Truth: 15.20
Index:85, Infer Result: 14.94, Ground Truth: 7.00
Index:86, Infer Result: 14.31, Ground Truth: 8.10
Index:87, Infer Result: 15.73, Ground Truth: 13.60
Index:88, Infer Result: 16.46, Ground Truth: 20.10
Index:89, Infer Result: 20.31, Ground Truth: 21.80
Index:90, Infer Result: 20.51, Ground Truth: 24.50
Index:91, Infer Result: 20.40, Ground Truth: 23.10
Index:92, Infer Result: 19.09, Ground Truth: 19.70
Index:93, Infer Result: 19.87, Ground Truth: 18.30
Index:94, Infer Result: 20.13, Ground Truth: 21.20
Index:95, Infer Result: 19.60, Ground Truth: 17.50
Index:96, Infer Result: 19.72, Ground Truth: 16.80
Index:97, Infer Result: 21.10, Ground Truth: 22.40
Index:98, Infer Result: 20.79, Ground Truth: 20.60
Index:99, Infer Result: 21.10, Ground Truth: 23.90
Index:100, Infer Result: 21.01, Ground Truth: 22.00
Index:101, Infer Result: 20.78, Ground Truth: 11.90
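The scatter plot gives a visual impression of the fit; to put a number on it, you can compute standard regression metrics from the `infer_results` and `groud_truths` lists filled above. A plain-Python sketch; `regression_metrics` is a hypothetical helper, not a Paddle API.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return MAE, RMSE and R^2 for two equal-length sequences."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    mean_t = sum(y_true) / n
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)
    ss_res = sum(e * e for e in errors)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return mae, rmse, r2


# With the lists collected in section 5.2:
# mae, rmse, r2 = regression_metrics(groud_truths, infer_results)
print(regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
```

A perfect model gives R^2 = 1, and predicting the mean test price for every house gives R^2 = 0, so this makes it easy to compare runs with different learning rates or epoch counts.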
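Final Words
Thanks for reading this far! If you found this useful, please give it a like; your support is my biggest motivation to keep writing!
<(^-^)>
My knowledge is limited; if you spot any mistakes, corrections are very welcome.
This article is intended only for learning and exchange, not for any commercial use. If there are any copyright concerns, please contact the author promptly.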