梯度提升树python
This story demonstrates the implementation of a “gradient boosted tree regression” model using python & spark machine learning. The dataset used is “bike rental info” from 2011–2012 in the capital bike share system. Our goal is to predict the count of bike rentals.
这个故事展示了如何使用python和spark机器学习实现“梯度提升树回归”模型。 在首都自行车共享系统中,使用的数据集是2011-2012年的“ 自行车租赁信息 ”。 我们的目标是预测自行车租赁的数量 。
1.加载数据 (1. Load the data)
The data in store is a CSV file. We are to create a spark data frame containing the bike data set. We cache this data so that we read it only once from the disk.
存储中的数据是CSV文件。 我们将创建一个包含自行车数据集的spark数据框。 我们缓存此数据,以便仅从磁盘读取一次。
#load the dataset & cache
df = spark.read.csv("/databricks-datasets/bikeSharing/data-001/hour.csv", header="true", inferSchema="true")df.cache()df.cache()#view the imported dataset
display(df)
输出: (Output:)
2.预处理数据 (2. Pre-Process the data)
Fields such as “weekday” are indexed, and all the other fields except date “dteday” are numerical. The count is our target "label". The “cnt” column we aim to predict equals the sum of the “casual” & “registered” columns.
索引诸如“工作日”的字段,除日期“ dteday”以外的所有其他字段均为数字。 计数是我们的目标“标签”。 我们旨在预测的“ cnt”列等于“休闲”和“注册”列的总和。
The next steps involve removing the “casual” and “registered” columns from the dataset to make sure we do not use them in predicting “cnt”. So, we discard the “dteday” and use the columns “season”, “yr”, “mnth” and “weekday”.
下一步涉及从数据集中删除“休闲”和“注册”列,以确保我们在预测“ cnt”时不使用它们。 因此,我们丢弃“ dteday”并使用“ season”,“ yr”,“ mnth”和“ weekday”列。
#drop the features mentioned
df = df.drop("instant").drop("dteday").drop("casual").drop("registered")#print the schema of our dataset to see t