[Python Deep Learning] Python Full-Stack Series (34)

Deep Learning

Chapter 15: Data Preparation

I. Data Preparation

1. What is data preparation?
  • Data preparation refers to reading sample data from an external source (mainly files) and passing it to the neural network in a certain way (randomly, in batches) for training or testing.
  • Data preparation consists of three steps (steps 2 and 3 are sketched after the reader example below):
    • Step 1: define a custom reader that generates the training/prediction data
    • Step 2: define the data-layer variables in the network configuration
    • Step 3: feed the data into the network for training/prediction
2. Why is data preparation needed?
  • Reading data from files. A program cannot hold large amounts of data in memory, so the data is usually stored in files, which requires a dedicated reading step.
  • Fast batch reading. Deep learning works with large sample sets that need to be read quickly and efficiently (batch reading mode).
  • Random reading. To improve the model's generalization ability, the data sometimes needs to be read in random order (shuffled reading mode).
3. Code
import paddle


# Basic reader
def reader_creator(file_path):
    def reader():
        with open(file_path, "r") as f:  # open the file
            lines = f.readlines()  # read all lines
            for line in lines:
                yield line.replace("\n", "")  # yield one sample at a time via the generator

    return reader


reader = reader_creator("test.txt")  # plain sequential reader
shuffle_reader = paddle.reader.shuffle(reader, 10)  # shuffled reader (buffer size 10)
batch_reader = paddle.batch(shuffle_reader, 3)  # batched shuffled reader (batch size 3)

# for data in reader():  # iterate over the sequential reader
# for data in shuffle_reader():  # iterate over the shuffled reader
for data in batch_reader():  # iterate over the batched shuffled reader
    print(data, end="")

"""
['888888888888,8', '111111111111,1', '444444444444,4']['333333333333,3', '666666666666,6', '222222222222,2']['000000000000,0', '999999999999,9', '555555555555,5']['777777777777,7']
"""

II. Model Saving and Loading

1. Saving and loading an inference model
  • Saving an inference model:
    • fluid.io.save_inference_model(dirname, feeded_var_names, target_vars, executor)
    • Parameters:
      • dirname (str): path where the inference model is saved
      • feeded_var_names (list[str]): names of the variables that must be fed at inference time
      • target_vars (list[Variable]): Variables that hold the prediction results
      • executor (Executor): the executor that saves the inference model
  • Loading an inference model (see the combined save/load sketch after this list):
    • fluid.io.load_inference_model(dirname, executor)
    • Parameters:
      • dirname (str): path where the inference model was saved
      • executor (Executor): the executor that runs the model
    • Return values:
      • Program: the Program used for inference
      • feed_target_names (list of str): names of the variables in the inference Program that accept the input data
      • fetch_targets (list of Variable): variables that hold the prediction results
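
The housing-price case study below uses both calls end to end; as a compact, self-contained illustration of the save/load round trip, a minimal sketch is shown here. The tiny untrained network, the random sample, and the "model/demo" path are placeholders, not from the original text.

import numpy as np
import paddle.fluid as fluid

# Minimal network: one 13-dimensional input "x" and a linear output
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y_predict = fluid.layers.fc(input=x, size=1, act=None)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Save only what is needed for inference
fluid.io.save_inference_model("model/demo",  # dirname
                              ["x"],         # feeded_var_names
                              [y_predict],   # target_vars
                              exe)           # executor

# Load: returns the inference Program plus the feed/fetch descriptors
infer_program, feed_target_names, fetch_targets = \
    fluid.io.load_inference_model("model/demo", exe)

sample = np.random.rand(1, 13).astype("float32")
result = exe.run(infer_program,
                 feed={feed_target_names[0]: sample},
                 fetch_list=fetch_targets)
print(result[0])  # prediction for the random sample
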
2. Saving and loading a checkpoint (incremental) model
  • Saving a checkpoint model (the matching load call is sketched after this list):
    • fluid.io.save_persistables(executor, dirname, main_program=None)
    • Parameters:
      • executor (Executor): the executor that saves the variables
      • dirname (str): path where the model is saved
      • main_program (Program|None): the Program whose variables are to be saved; if None, default_main_program is used
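
A checkpoint saved this way can be restored with the matching fluid.io.load_persistables call so that training can resume. A minimal sketch of the save/restore pair; the tiny network and the "checkpoint/housing" path are placeholders, not from the original text.

import paddle.fluid as fluid

# A tiny network and optimizer so that there are persistable variables to save
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.001).minimize(cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Save all persistable variables (parameters, optimizer state) as a checkpoint
fluid.io.save_persistables(executor=exe, dirname="checkpoint/housing")

# In a later run: rebuild the same network, then restore the checkpoint and keep training
fluid.io.load_persistables(executor=exe, dirname="checkpoint/housing")
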
3. fluid API structure diagram

[Figure: fluid API structure diagram]

Chapter 16: Comprehensive Case Study: Boston Housing Price Prediction

Task Overview

1. Dataset and task
  • Dataset:
    • Size: 506 samples
    • Number of features: 13 (see the figure below)
    • Label: median house price
  • Task: predict the median house price from the sample data (a regression problem); a quick way to inspect the dataset is sketched after the figure below
    [Figure: descriptions of the 13 dataset features]
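
To verify the numbers above, the built-in reader used in the code below can be iterated directly. Note that paddle's uci_housing module splits the 506 samples into a training portion and a test portion, so each reader yields only its own share. A minimal sketch:

import paddle

# Readers for the built-in UCI housing dataset used in the code below
train_reader = paddle.dataset.uci_housing.train()
test_reader = paddle.dataset.uci_housing.test()

train_samples = list(train_reader())
test_samples = list(test_reader())
print(len(train_samples), len(test_samples))  # training/test split of the 506 samples

features, label = train_samples[0]
print(features.shape, label)  # 13-dimensional feature vector and its price label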
2. Approach

[Figure: overall workflow of the housing-price prediction solution]

3. Code
# Boston housing price prediction (multivariate regression)
"""
Dataset: 506 housing samples, each with 13 features and 1 label
"""
import os.path

import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt

# Step 1: data preparation
# shuffle buffer size
BUF_SIZE = 500
# batch size
BATCH_SIZE = 20

random_reader = paddle.reader.shuffle(paddle.dataset.uci_housing.train(),  # training-set reader
                                      buf_size=BUF_SIZE)
train_reader = paddle.batch(random_reader, batch_size=BATCH_SIZE)  # batched reader
# # print the raw data
# train_data = paddle.dataset.uci_housing.train()
# for sample in train_data():
#     print(sample)
"""
# 13 features
# label: median house price
(array([ 0.23814999, -0.11363636,  0.25525005, -0.06916996,  0.28457807,
       -0.17927465,  0.2824418 , -0.1902575 ,  0.62828665,  0.49191383,
        0.18558153,  0.10143217,  0.19638346]), array([8.3]))
"""
# Step 2: model definition
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
# define a fully connected model
y_predict = fluid.layers.fc(input=x,  # input
                            size=1,  # number of output values
                            act=None)  # activation function
# loss function
cost = fluid.layers.square_error_cost(input=y_predict,  # predicted values
                                      label=y)  # ground-truth values
avg_cost = fluid.layers.mean(cost)  # mean squared error
# optimizer
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_cost)  # objective function to minimize
# Step 3: model training and saving
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# feeder: converts each batch into the tensor format expected by the model
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
iter = 0
iters = []
train_costs = []
EPOCHE_NUM = 120
model_save_dir = "model/uci_housing"  # model save path

for pass_id in range(EPOCHE_NUM):
    train_cost = 0
    i = 0
    for data in train_reader():
        i += 1
        train_cost = exe.run(program=fluid.default_main_program(),
                             feed=feeder.feed(data),
                             fetch_list=[avg_cost])
        if i % 20 == 0:
            print("pass_id:%d, cost:%f" % (pass_id, train_cost[0][0]))
        iter = iter + BATCH_SIZE
        iters.append(iter)  # record the number of training iterations
        train_costs.append(train_cost[0][0])  # record the loss value

# save the model
if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)
fluid.io.save_inference_model(model_save_dir,  # model save path
                              ["x"],  # variables to feed at inference time
                              [y_predict],  # variables holding the prediction results
                              exe)  # executor that saves the model

# visualize the training process
plt.figure("Training Cost")
plt.title("Training Cost", fontsize=24)
plt.xlabel("iter", fontsize=14)
plt.ylabel("cost", fontsize=14)
plt.plot(iters, train_costs, color="red", label="Training Cost")
plt.grid()
plt.savefig("train.png")
# Step 4: model loading and prediction
infer_exe = fluid.Executor(place)
infer_result = []  # list of predicted values
ground_truths = []  # list of ground-truth values

# load the model
infer_program, feed_target_names, fetch_targets = \
    fluid.io.load_inference_model(model_save_dir,  # model save path
                                  infer_exe)  # executor to load the model onto

# test-set reader
infer_reader = paddle.batch(paddle.dataset.uci_housing.test(),  # read the test set
                            batch_size=200)
test_data = next(infer_reader())  # fetch one batch of data
test_x = np.array([data[0] for data in test_data]).astype("float32")
test_y = np.array([data[1] for data in test_data]).astype("float32")
# build the feed dictionary
x_name = feed_target_names[0]  # name of the input variable
results = infer_exe.run(infer_program,  # Program that performs the prediction
                        feed={x_name: test_x},  # input data
                        fetch_list=fetch_targets)  # fetch the prediction results
# predicted values
for idx, val in enumerate(results[0]):
    print("%d: %f" % (idx, val))
    infer_result.append(val)

# ground-truth values
for idx, val in enumerate(test_y):
    print("%d: %f" % (idx, val))
    ground_truths.append(val)

# visualize the prediction results
plt.figure("infer")
plt.title("infer", fontsize=24)
plt.xlabel("ground truth", fontsize=14)
plt.ylabel("infer result", fontsize=14)
x = np.arange(1, 30)
y = x
plt.plot(x, y)  # draw the y = x reference line
plt.scatter(ground_truths, infer_result, color="green", label="infer")
plt.grid()
plt.legend()
plt.savefig("predict.png")
plt.show()
4. Execution results (one training-cost line per epoch, followed by the predicted values and the ground-truth values for the 102 test samples)
pass_id:0, cost:656.089905
pass_id:1, cost:451.363678
pass_id:2, cost:584.316589
pass_id:3, cost:477.108948
pass_id:4, cost:286.599640
pass_id:5, cost:445.707367
pass_id:6, cost:446.433838
pass_id:7, cost:335.848511
pass_id:8, cost:273.062225
pass_id:9, cost:255.784912
pass_id:10, cost:278.432373
pass_id:11, cost:276.121887
pass_id:12, cost:175.726196
pass_id:13, cost:170.238754
pass_id:14, cost:170.852570
pass_id:15, cost:223.544922
pass_id:16, cost:166.904495
pass_id:17, cost:321.751526
pass_id:18, cost:280.356567
pass_id:19, cost:111.091576
pass_id:20, cost:124.681442
pass_id:21, cost:64.695580
pass_id:22, cost:129.477448
pass_id:23, cost:133.440948
pass_id:24, cost:130.348145
pass_id:25, cost:102.667458
pass_id:26, cost:64.281265
pass_id:27, cost:142.222763
pass_id:28, cost:26.178593
pass_id:29, cost:220.263596
pass_id:30, cost:169.756500
pass_id:31, cost:119.223656
pass_id:32, cost:87.624367
pass_id:33, cost:59.109009
pass_id:34, cost:164.397720
pass_id:35, cost:98.710800
pass_id:36, cost:117.974304
pass_id:37, cost:64.506653
pass_id:38, cost:104.113625
pass_id:39, cost:78.288124
pass_id:40, cost:103.716820
pass_id:41, cost:95.082603
pass_id:42, cost:34.526741
pass_id:43, cost:99.486519
pass_id:44, cost:122.406921
pass_id:45, cost:176.348862
pass_id:46, cost:56.122032
pass_id:47, cost:48.959282
pass_id:48, cost:114.838989
pass_id:49, cost:173.656082
pass_id:50, cost:96.353210
pass_id:51, cost:129.643478
pass_id:52, cost:78.118484
pass_id:53, cost:56.693672
pass_id:54, cost:67.857742
pass_id:55, cost:15.136653
pass_id:56, cost:87.636497
pass_id:57, cost:45.029011
pass_id:58, cost:108.201218
pass_id:59, cost:32.179466
pass_id:60, cost:34.872448
pass_id:61, cost:94.557373
pass_id:62, cost:127.176132
pass_id:63, cost:81.021133
pass_id:64, cost:27.862711
pass_id:65, cost:75.477615
pass_id:66, cost:119.252541
pass_id:67, cost:93.257736
pass_id:68, cost:25.911819
pass_id:69, cost:17.109428
pass_id:70, cost:35.836407
pass_id:71, cost:69.057404
pass_id:72, cost:112.613510
pass_id:73, cost:68.981125
pass_id:74, cost:49.957832
pass_id:75, cost:20.481647
pass_id:76, cost:59.729126
pass_id:77, cost:45.460415
pass_id:78, cost:22.951813
pass_id:79, cost:40.394081
pass_id:80, cost:37.409126
pass_id:81, cost:41.443184
pass_id:82, cost:70.590271
pass_id:83, cost:54.799217
pass_id:84, cost:41.712090
pass_id:85, cost:79.634201
pass_id:86, cost:103.184982
pass_id:87, cost:15.930639
pass_id:88, cost:97.250771
pass_id:89, cost:43.428303
pass_id:90, cost:104.876076
pass_id:91, cost:71.580521
pass_id:92, cost:38.239330
pass_id:93, cost:16.533834
pass_id:94, cost:43.827812
pass_id:95, cost:16.911013
pass_id:96, cost:66.245995
pass_id:97, cost:45.150234
pass_id:98, cost:13.511981
pass_id:99, cost:41.205372
pass_id:100, cost:17.888485
pass_id:101, cost:51.672241
pass_id:102, cost:54.815704
pass_id:103, cost:20.194555
pass_id:104, cost:110.166306
pass_id:105, cost:53.912636
pass_id:106, cost:26.374447
pass_id:107, cost:14.297429
pass_id:108, cost:23.325668
pass_id:109, cost:60.575584
pass_id:110, cost:46.281517
pass_id:111, cost:162.359894
pass_id:112, cost:46.856792
pass_id:113, cost:101.333237
pass_id:114, cost:45.367104
pass_id:115, cost:36.769276
pass_id:116, cost:31.477345
pass_id:117, cost:59.371132
pass_id:118, cost:19.479343
pass_id:119, cost:45.612984
0: 14.273354
1: 14.844604
2: 13.930457
3: 16.414011
4: 14.629221
5: 15.708012
6: 15.217474
7: 14.663005
8: 11.310017
9: 14.527328
10: 10.711575
11: 13.088373
12: 13.984264
13: 13.269247
14: 13.599314
15: 14.719706
16: 16.353481
17: 16.084949
18: 16.393032
19: 14.104583
20: 14.957327
21: 13.355427
22: 15.661732
23: 15.234718
24: 14.793789
25: 14.038326
26: 15.540333
27: 15.426285
28: 16.688440
29: 15.434376
30: 15.232512
31: 14.383001
32: 14.589976
33: 12.967784
34: 12.295956
35: 15.128961
36: 15.319921
37: 16.048567
38: 16.329090
39: 16.142662
40: 14.333231
41: 13.827244
42: 15.941467
43: 16.392628
44: 16.204958
45: 15.768801
46: 14.903126
47: 16.429163
48: 16.509726
49: 17.171146
50: 14.705102
51: 15.027884
52: 14.286346
53: 14.653322
54: 16.267881
55: 16.920368
56: 16.329792
57: 17.043941
58: 17.219616
59: 17.729343
60: 17.799040
61: 17.409719
62: 14.858363
63: 15.860466
64: 16.875864
65: 17.623283
66: 17.241474
67: 17.786171
68: 17.867687
69: 18.561293
70: 15.950287
71: 15.370593
72: 16.801212
73: 14.720805
74: 16.501842
75: 17.383152
76: 18.622833
77: 19.182955
78: 19.481457
79: 18.825390
80: 18.174511
81: 18.768740
82: 17.449657
83: 18.290375
84: 17.046276
85: 15.830663
86: 14.686304
87: 17.283913
88: 18.288227
89: 22.529139
90: 22.696047
91: 22.225986
92: 20.664625
93: 21.986166
94: 22.403265
95: 21.603868
96: 21.927729
97: 23.351410
98: 22.955769
99: 23.764650
100: 23.540255
101: 22.982683
0: 8.500000
1: 5.000000
2: 11.900000
3: 27.900000
4: 17.200001
5: 27.500000
6: 15.000000
7: 17.200001
8: 17.900000
9: 16.299999
10: 7.000000
11: 7.200000
12: 7.500000
13: 10.400000
14: 8.800000
15: 8.400000
16: 16.700001
17: 14.200000
18: 20.799999
19: 13.400000
20: 11.700000
21: 8.300000
22: 10.200000
23: 10.900000
24: 11.000000
25: 9.500000
26: 14.500000
27: 14.100000
28: 16.100000
29: 14.300000
30: 11.700000
31: 13.400000
32: 9.600000
33: 8.700000
34: 8.400000
35: 12.800000
36: 10.500000
37: 17.100000
38: 18.400000
39: 15.400000
40: 10.800000
41: 11.800000
42: 14.900000
43: 12.600000
44: 14.100000
45: 13.000000
46: 13.400000
47: 15.200000
48: 16.100000
49: 17.799999
50: 14.900000
51: 14.100000
52: 12.700000
53: 13.500000
54: 14.900000
55: 20.000000
56: 16.400000
57: 17.700001
58: 19.500000
59: 20.200001
60: 21.400000
61: 19.900000
62: 19.000000
63: 19.100000
64: 19.100000
65: 20.100000
66: 19.900000
67: 19.600000
68: 23.200001
69: 29.799999
70: 13.800000
71: 13.300000
72: 16.700001
73: 12.000000
74: 14.600000
75: 21.400000
76: 23.000000
77: 23.700001
78: 25.000000
79: 21.799999
80: 20.600000
81: 21.200001
82: 19.100000
83: 20.600000
84: 15.200000
85: 7.000000
86: 8.100000
87: 13.600000
88: 20.100000
89: 21.799999
90: 24.500000
91: 23.100000
92: 19.700001
93: 18.299999
94: 21.200001
95: 17.500000
96: 16.799999
97: 22.400000
98: 20.600000
99: 23.900000
100: 22.000000
101: 11.900000

[Figure: scatter plot of predictions vs. ground truth (predict.png)]
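
Beyond the scatter plot, the fit can be summarized with a single number. A minimal sketch, assuming the infer_result and ground_truths lists collected by the script above are still in scope:

import numpy as np

# Mean squared error between predictions and ground truth on the test batch
pred = np.array(infer_result).flatten()
true = np.array(ground_truths).flatten()
print("test MSE: %.4f" % np.mean((pred - true) ** 2))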
