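Machine Learning Project
(Keywords: sklearn, pandas, Pipeline, linear regression model, support vector machine model, random forest model)
Dataset: https://www.kaggle.com/c/sberbank-russian-housing-market
Dataset background: Sberbank is challenging Kagglers to develop algorithms that use a broad range of features to predict real-estate prices. Competitors rely on a rich dataset that includes housing data and macroeconomic patterns. An accurate forecasting model would let Sberbank provide more certainty to its customers in an uncertain economy.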
import os
import pandas as pd

HOUSE_PATH = './datasets'

def load_housing_data(housing_path=HOUSE_PATH):
    csv_path = os.path.join(housing_path, 'MoscowHouseTrain.csv')
    return pd.read_csv(csv_path)

housing = load_housing_data()
housing.head()
| | id | timestamp | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | ... | cafe_count_5000_price_2500 | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 | price_doc |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2011-08-20 | 43 | 27.0 | 4.0 | NaN | NaN | NaN | NaN | NaN | ... | 9 | 4 | 0 | 13 | 22 | 1 | 0 | 52 | 4 | 5850000 |
1 | 2 | 2011-08-23 | 34 | 19.0 | 3.0 | NaN | NaN | NaN | NaN | NaN | ... | 15 | 3 | 0 | 15 | 29 | 1 | 10 | 66 | 14 | 6000000 |
2 | 3 | 2011-08-27 | 43 | 29.0 | 2.0 | NaN | NaN | NaN | NaN | NaN | ... | 10 | 3 | 0 | 11 | 27 | 0 | 4 | 67 | 10 | 5700000 |
3 | 4 | 2011-09-01 | 89 | 50.0 | 9.0 | NaN | NaN | NaN | NaN | NaN | ... | 11 | 2 | 1 | 4 | 4 | 0 | 0 | 26 | 3 | 13100000 |
4 | 5 | 2011-09-05 | 77 | 77.0 | 4.0 | NaN | NaN | NaN | NaN | NaN | ... | 319 | 108 | 17 | 135 | 236 | 2 | 91 | 195 | 14 | 16331452 |
5 rows × 292 columns
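# Work with only the first 10,000 rows of the training file.
# .copy() avoids a SettingWithCopyWarning when a column is added later.
housingUse = housing.iloc[:10000].copy()
housingUse.info()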
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Columns: 292 entries, id to price_doc
dtypes: float64(119), int64(157), object(16)
memory usage: 22.3+ MB
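With 292 columns, `info()` prints only a dtype summary and not the per-column non-null counts. Below is a minimal sketch that ranks columns by their fraction of missing values. It runs on a small hypothetical frame: the column names mirror the dataset, but the values are made up.

```python
import numpy as np
import pandas as pd

def missing_report(df: pd.DataFrame, top: int = 10) -> pd.Series:
    """Fraction of missing values per column, largest first (complete columns omitted)."""
    ratios = df.isna().mean()  # mean of a boolean mask = fraction of NaNs
    return ratios[ratios > 0].sort_values(ascending=False).head(top)

# Hypothetical stand-in for housingUse
demo = pd.DataFrame({
    "full_sq": [43, 34, 43, 89],
    "life_sq": [27.0, 19.0, np.nan, 50.0],
    "max_floor": [np.nan, np.nan, np.nan, 9.0],
})
print(missing_report(demo))
```

On `housingUse` itself, `missing_report(housingUse, top=20)` would show how sparse columns such as `max_floor` are. According to `describe()`, that column has only 494 non-null values out of 10,000.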
housingUse.describe()
| | id | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | state | ... | cafe_count_5000_price_2500 | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 | price_doc |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 10000.00000 | 10000.000000 | 7752.000000 | 9834.000000 | 494.000000 | 494.000000 | 437.000000 | 494.000000 | 494.000000 | 398.000000 | ... | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 10000.000000 | 1.000000e+04 |
mean | 5002.24830 | 53.709900 | 34.181760 | 7.964308 | 11.921053 | 1.927126 | 1957.773455 | 1.935223 | 7.117409 | 2.419598 | ... | 30.614500 | 10.164100 | 1.657500 | 14.497400 | 29.372800 | 0.488500 | 8.144400 | 52.529400 | 5.991500 | 6.400306e+06 |
std | 2887.27684 | 58.306234 | 22.703808 | 5.383705 | 6.096573 | 1.472794 | 211.639811 | 0.774763 | 6.741713 | 0.752707 | ... | 70.407453 | 27.145192 | 5.193601 | 27.874316 | 45.477441 | 0.606716 | 19.782004 | 44.945444 | 4.763559 | 4.376049e+06 |
min | 1.00000 | 5.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.900000e+05 |
25% | 2502.75000 | 38.000000 | 20.000000 | 4.000000 | 8.000000 | 1.000000 | 1965.000000 | 1.000000 | 5.000000 | 2.000000 | ... | 2.000000 | 1.000000 | 0.000000 | 2.000000 | 8.000000 | 0.000000 | 0.000000 | 12.000000 | 1.000000 | 4.233045e+06 |
50% | 5002.50000 | 47.000000 | 30.000000 | 7.000000 | 12.000000 | 1.000000 | 1976.000000 | 2.000000 | 7.000000 | 2.500000 | ... | 8.000000 | 2.000000 | 0.000000 | 7.000000 | 16.000000 | 0.000000 | 2.000000 | 47.000000 | 5.000000 | 5.700000e+06 |
75% | 7502.25000 | 61.250000 | 42.000000 | 11.000000 | 17.000000 | 2.000000 | 1997.000000 | 2.000000 | 9.000000 | 3.000000 | ... | 20.000000 | 4.000000 | 1.000000 | 12.000000 | 28.000000 | 1.000000 | 6.000000 | 75.000000 | 10.000000 | 7.301552e+06 |
max | 10002.00000 | 5326.000000 | 802.000000 | 44.000000 | 40.000000 | 6.000000 | 2015.000000 | 5.000000 | 123.000000 | 4.000000 | ... | 377.000000 | 147.000000 | 30.000000 | 151.000000 | 250.000000 | 2.000000 | 105.000000 | 218.000000 | 20.000000 | 1.111111e+08 |
8 rows × 276 columns
%matplotlib inline
import matplotlib.pyplot as plt
# One histogram per numeric column; pandas builds the subplot grid itself
housingUse.hist(bins=50, figsize=(60, 50), edgecolor="blue")
plt.show()
2. Split the dataset
Explore whether the target sale price is a suitable basis for stratified sampling:
housingUse["price_doc"].hist(bins=70, edgecolor="blue")
housingUse["price_doc"].describe()
count 1.000000e+04
mean 6.400306e+06
std 4.376049e+06
min 1.900000e+05
25% 4.233045e+06
50% 5.700000e+06
75% 7.301552e+06
max 1.111111e+08
Name: price_doc, dtype: float64
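import numpy as np
# Bin the sale price into 5 categories to stratify on. The inner edges roughly follow
# the quartiles from describe() (4.23e6, 5.7e6, 7.30e6). The column name "income_cat"
# is kept from the original code, although it holds price bins.
housingUse["income_cat"] = pd.cut(housingUse["price_doc"],
                                  bins=[0., 4200000., 5700000., 7300000., 9000000., np.inf],
                                  labels=[1, 2, 3, 4, 5])
housingUse["income_cat"].hist(edgecolor="black")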
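By default, `pd.cut` builds right-closed intervals (`right=True`). A price of exactly 5,700,000 therefore lands in category 2, not 3. The check below uses the five sale prices from the first `housing.head()` output. It repeats the bin edges from the cell above so that it runs on its own.

```python
import numpy as np
import pandas as pd

bins = [0., 4200000., 5700000., 7300000., 9000000., np.inf]
labels = [1, 2, 3, 4, 5]

# price_doc values from the first five rows shown by housing.head()
prices = pd.Series([5850000, 6000000, 5700000, 13100000, 16331452])
cats = pd.cut(prices, bins=bins, labels=labels)  # intervals: (0, 4.2e6], (4.2e6, 5.7e6], ...
print(cats.tolist())  # [3, 3, 2, 5, 5]
```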
# housingUse["income_cat"] = np.ceil(housingUse["price_doc"] / 100000000.)
# housingUse["income_cat"].where(housingUse["price_doc"] < 9, 9.0, inplace=True)
# housingUse["income_cat"].hist(edgecolor="black")
housingUse.head()
| | id | timestamp | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | ... | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 | price_doc | income_cat |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2011-08-20 | 43 | 27.0 | 4.0 | NaN | NaN | NaN | NaN | NaN | ... | 4 | 0 | 13 | 22 | 1 | 0 | 52 | 4 | 5850000 | 3 |
1 | 2 | 2011-08-23 | 34 | 19.0 | 3.0 | NaN | NaN | NaN | NaN | NaN | ... | 3 | 0 | 15 | 29 | 1 | 10 | 66 | 14 | 6000000 | 3 |
2 | 3 | 2011-08-27 | 43 | 29.0 | 2.0 | NaN | NaN | NaN | NaN | NaN | ... | 3 | 0 | 11 | 27 | 0 | 4 | 67 | 10 | 5700000 | 2 |
3 | 4 | 2011-09-01 | 89 | 50.0 | 9.0 | NaN | NaN | NaN | NaN | NaN | ... | 2 | 1 | 4 | 4 | 0 | 0 | 26 | 3 | 13100000 | 5 |
4 | 5 | 2011-09-05 | 77 | 77.0 | 4.0 | NaN | NaN | NaN | NaN | NaN | ... | 108 | 17 | 135 | 236 | 2 | 91 | 195 | 14 | 16331452 | 5 |
5 rows × 293 columns
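# Note: this divides by len(housing) (the full file), not len(housingUse), so the values
# below sum to about 0.33 rather than 1; use len(housingUse) for proportions within the subset.
housingUse["income_cat"].value_counts()/len(housing)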
2 0.085754
1 0.080995
3 0.079387
5 0.048013
4 0.034032
Name: income_cat, dtype: float64
# Stratified sampling based on the (binned) target sale price
from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housingUse, housingUse["income_cat"]):
    # split() returns positional indices, so select with iloc
    strat_train_set = housingUse.iloc[train_index].copy()
    strat_test_set = housingUse.iloc[test_index].copy()
# Afterwards, drop the helper column that was added only for sampling
strat_train_set.drop('income_cat', axis=1, inplace=True)
strat_test_set.drop('income_cat', axis=1, inplace=True)
print("Training set size: " + str(strat_train_set.shape) + "\nTest set size: " + str(strat_test_set.shape))
Training set size: (8000, 292)
Test set size: (2000, 292)
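One way to check that stratification helped is to compare the category proportions in the test set with those in the full data, for both a stratified split and a purely random one. The sketch below does this on a synthetic category column. The probabilities are made up and only loosely shaped like the `income_cat` counts above. With the real data you would pass `housingUse["income_cat"]` instead.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

rng = np.random.default_rng(42)
# Synthetic stand-in for income_cat: 5 imbalanced price categories
df = pd.DataFrame({"cat": rng.choice([1, 2, 3, 4, 5], size=10000,
                                     p=[0.26, 0.25, 0.24, 0.15, 0.10])})

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(split.split(df, df["cat"]))
strat_test = df.iloc[test_idx]
_, rand_test = train_test_split(df, test_size=0.2, random_state=42)

comparison = pd.DataFrame({
    "overall": df["cat"].value_counts(normalize=True),
    "stratified": strat_test["cat"].value_counts(normalize=True),
    "random": rand_test["cat"].value_counts(normalize=True),
}).sort_index()
# Relative error of each test-set proportion versus the full data, in percent
comparison["strat_err_%"] = 100 * (comparison["stratified"] / comparison["overall"] - 1)
comparison["rand_err_%"] = 100 * (comparison["random"] / comparison["overall"] - 1)
print(comparison.round(3))
```

The stratified column should match the overall proportions up to rounding (at most one sample per class). The random split can drift further, especially for the smaller categories.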
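3. Explore correlations in the dataset
housing = strat_train_set.copy()
# numeric_only=True skips the text columns (needed in pandas >= 2.0, where corr() no longer drops them silently)
corr_matrix = housing.corr(numeric_only=True)
corr_matrix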
| | id | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | state | ... | cafe_count_5000_price_2500 | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 | price_doc |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
id | 1.000000 | 0.012074 | 0.121534 | 0.087292 | -0.013266 | -0.003899 | -0.060945 | -0.043295 | -0.056324 | -0.016977 | ... | -0.042916 | -0.033283 | -0.023342 | -0.059320 | -0.061332 | -0.002739 | -0.047711 | -0.163295 | -0.185277 | 0.012557 |
full_sq | 0.012074 | 1.000000 | 0.208281 | 0.076964 | 0.148558 | -0.085975 | -0.157111 | 0.700592 | 0.232639 | -0.067896 | ... | 0.010269 | 0.009751 | 0.012328 | 0.002156 | 0.004670 | 0.004773 | 0.004127 | -0.011730 | -0.031863 | 0.198563 |
life_sq | 0.121534 | 0.208281 | 1.000000 | 0.132810 | 0.047904 | -0.051019 | -0.042960 | 0.395729 | 0.143683 | -0.020084 | ... | 0.036823 | 0.041765 | 0.048429 | 0.019060 | 0.017277 | -0.004664 | 0.022581 | -0.069183 | -0.138471 | 0.360761 |
floor | 0.087292 | 0.076964 | 0.132810 | 1.000000 | 0.464896 | -0.077248 | -0.146436 | -0.033062 | 0.040441 | -0.012902 | ... | -0.048963 | -0.047175 | -0.038238 | -0.058347 | -0.059083 | -0.003731 | -0.059940 | -0.129194 | -0.147641 | 0.138167 |
max_floor | -0.013266 | 0.148558 | 0.047904 | 0.464896 | 1.000000 | 0.033218 | 0.006264 | -0.043865 | 0.124776 | -0.049004 | ... | -0.226110 | -0.223682 | -0.224157 | -0.194559 | -0.214707 | -0.194212 | -0.225493 | -0.241819 | -0.207358 | 0.112951 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
mosque_count_5000 | -0.002739 | 0.004773 | -0.004664 | -0.003731 | -0.194212 | -0.009277 | 0.065323 | 0.128494 | -0.064821 | 0.177200 | ... | 0.485745 | 0.453952 | 0.444877 | 0.448511 | 0.505794 | 1.000000 | 0.456811 | 0.439594 | 0.170313 | 0.129135 |
leisure_count_5000 | -0.047711 | 0.004127 | 0.022581 | -0.059940 | -0.225493 | 0.068069 | 0.023510 | 0.151868 | -0.007931 | 0.118095 | ... | 0.985805 | 0.974862 | 0.945952 | 0.965860 | 0.973137 | 0.456811 | 1.000000 | 0.806361 | 0.460463 | 0.179983 |
sport_count_5000 | -0.163295 | -0.011730 | -0.069183 | -0.129194 | -0.241819 | 0.122477 | 0.077763 | 0.155859 | -0.021454 | 0.223212 | ... | 0.820315 | 0.771293 | 0.743241 | 0.822253 | 0.845495 | 0.439594 | 0.806361 | 1.000000 | 0.729183 | 0.263404 |
market_count_5000 | -0.185277 | -0.031863 | -0.138471 | -0.147641 | -0.207358 | 0.053439 | 0.098520 | 0.062009 | -0.014324 | 0.296082 | ... | 0.429567 | 0.375837 | 0.345498 | 0.479993 | 0.512822 | 0.170313 | 0.460463 | 0.729183 | 1.000000 | 0.172720 |
price_doc | 0.012557 | 0.198563 | 0.360761 | 0.138167 | 0.112951 | -0.049962 | -0.020732 | 0.495512 | 0.125347 | 0.170591 | ... | 0.203641 | 0.184150 | 0.188428 | 0.182466 | 0.193230 | 0.129135 | 0.179983 | 0.263404 | 0.172720 | 1.000000 |
276 rows × 276 columns
# Features most correlated with the sale price (price_doc itself comes first, with r = 1)
related_features = corr_matrix["price_doc"].sort_values(ascending=False)
print(related_features[0:11])
price_doc 1.000000
num_room 0.495512
life_sq 0.360761
sport_count_3000 0.289984
sport_count_2000 0.281693
trc_count_5000 0.267114
sport_count_5000 0.263404
sport_count_1500 0.254639
trc_count_3000 0.254040
sport_objects_raion 0.238254
trc_sqm_5000 0.233633
Name: price_doc, dtype: float64
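Sorting the signed correlations with `sort_values(ascending=False)` pushes strongly *negative* features to the bottom of the list, so they can never appear in the top 10. The variation below ranks features by absolute Pearson correlation and drops the target's self-correlation. It runs on synthetic data, where the hypothetical `dist_km` column is built to correlate negatively with price.

```python
import numpy as np
import pandas as pd

def top_correlated(df: pd.DataFrame, target: str, k: int = 10) -> pd.Series:
    """The k numeric features with the largest |Pearson r| against target (signs kept)."""
    corr = df.corr(numeric_only=True)[target].drop(target)  # drop self-correlation (always 1.0)
    order = corr.abs().sort_values(ascending=False).index
    return corr[order].head(k)

rng = np.random.default_rng(0)
n = 500
rooms = rng.integers(1, 6, size=n)
demo = pd.DataFrame({
    "num_room": rooms,
    "noise": rng.normal(size=n),
    "dist_km": rng.uniform(0, 30, size=n),
    "district": rng.choice(["a", "b"], size=n),  # text column, ignored by numeric_only
})
demo["price_doc"] = (2_000_000 * rooms - 200_000 * demo["dist_km"]
                     + rng.normal(0, 500_000, n))
print(top_correlated(demo, "price_doc", k=3))
```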
top_10_features = list(related_features.index[0:11])  # price_doc itself plus the 10 most correlated features
print(top_10_features)
from pandas.plotting import scatter_matrix
scatter_matrix(housing[top_10_features], figsize=(24, 16))
# plt.scatter(housing[top_10_features.index[0]], housing['price_doc'], s=5)
# for f in top_10_features.index[1:]:
# plt.scatter(housing[f], housing['price_doc'], s=5)
# plt.xlabel("Top 10 Features")
# plt.ylabel("price_doc")
# The number of rooms (num_room) appears to be the feature most correlated with the sale price
housing.plot(kind="scatter", x="num_room", y="price_doc", alpha=0.1)
<AxesSubplot:xlabel='num_room', ylabel='price_doc'>
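Separate the training features from the labels
# Separate the training features from the labels
housing_train = strat_train_set.drop("price_doc", axis=1)  # drop() returns a copy, so strat_train_set is unchanged; inplace=True would modify it
housing_train_labels = strat_train_set["price_doc"].copy()
print("ML training set: " + str(housing_train.shape) + "\nLabels: " + str(housing_train_labels.shape))
ML training set: (8000, 291)
Labels: (8000,)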
pd.isna(housing_train["max_floor"]).sum()
housing_train.dtypes
id int64
timestamp object
full_sq int64
life_sq float64
floor float64
...
church_count_5000 int64
mosque_count_5000 int64
leisure_count_5000 int64
sport_count_5000 int64
market_count_5000 int64
Length: 291, dtype: object
# # Get all text-type features
# text_features = housing_train.select_dtypes(include=[object])
# # Get all numeric features
# numeric_features = housing_train.select_dtypes(include=[np.number])
# # Print the results
# print("Text features:\n", text_features)
# print("Numeric features:\n", numeric_features)
# Names of all text (non-numeric) features. np.object was removed in NumPy 1.24,
# so select by excluding numeric dtypes instead (this dataset has only object and numeric columns).
text_feature_names = list(housing_train.select_dtypes(exclude=[np.number]).columns)
# Names of all numeric features
numeric_feature_names = list(housing_train.select_dtypes(include=[np.number]).columns)
# Print the results
print("Text feature names:", text_feature_names)
print("Numeric feature names:", numeric_feature_names)
Text feature names: ['timestamp', 'product_type', 'sub_area', 'culture_objects_top_25', 'thermal_power_plant_raion', 'incineration_raion', 'oil_chemistry_raion', 'radiation_raion', 'railroad_terminal_raion', 'big_market_raion', 'nuclear_reactor_raion', 'detention_facility_raion', 'water_1line', 'big_road1_1line', 'railroad_1line', 'ecology']
Numeric feature names: ['id', 'full_sq', 'life_sq', 'floor', 'max_floor', 'material', 'build_year', 'num_room', 'kitch_sq', 'state', 'area_m', 'raion_popul', 'green_zone_part', 'indust_part', 'children_preschool', 'preschool_quota', 'preschool_education_centers_raion', 'children_school', 'school_quota', 'school_education_centers_raion', 'school_education_centers_top_20_raion', 'hospital_beds_raion', 'healthcare_centers_raion', 'university_top_20_raion', 'sport_objects_raion', 'additional_education_raion', 'culture_objects_top_25_raion', 'shopping_centers_raion', 'office_raion', 'full_all', 'male_f', 'female_f', 'young_all', 'young_male', 'young_female', 'work_all', 'work_male', 'work_female', 'ekder_all', 'ekder_male', 'ekder_female', '0_6_all', '0_6_male', '0_6_female', '7_14_all', '7_14_male', '7_14_female', '0_17_all', '0_17_male', '0_17_female', '16_29_all', '16_29_male', '16_29_female', '0_13_all', '0_13_male', '0_13_female', 'raion_build_count_with_material_info', 'build_count_block', 'build_count_wood', 'build_count_frame', 'build_count_brick', 'build_count_monolith', 'build_count_panel', 'build_count_foam', 'build_count_slag', 'build_count_mix', 'raion_build_count_with_builddate_info', 'build_count_before_1920', 'build_count_1921-1945', 'build_count_1946-1970', 'build_count_1971-1995', 'build_count_after_1995', 'ID_metro', 'metro_min_avto', 'metro_km_avto', 'metro_min_walk', 'metro_km_walk', 'kindergarten_km', 'school_km', 'park_km', 'green_zone_km', 'industrial_km', 'water_treatment_km', 'cemetery_km', 'incineration_km', 'railroad_station_walk_km', 'railroad_station_walk_min', 'ID_railroad_station_walk', 'railroad_station_avto_km', 'railroad_station_avto_min', 'ID_railroad_station_avto', 'public_transport_station_km', 'public_transport_station_min_walk', 'water_km', 'mkad_km', 'ttk_km', 'sadovoe_km', 'bulvar_ring_km', 'kremlin_km', 'big_road1_km', 'ID_big_road1', 'big_road2_km', 'ID_big_road2', 'railroad_km', 'zd_vokzaly_avto_km', 'ID_railroad_terminal', 
'bus_terminal_avto_km', 'ID_bus_terminal', 'oil_chemistry_km', 'nuclear_reactor_km', 'radiation_km', 'power_transmission_line_km', 'thermal_power_plant_km', 'ts_km', 'big_market_km', 'market_shop_km', 'fitness_km', 'swim_pool_km', 'ice_rink_km', 'stadium_km', 'basketball_km', 'hospice_morgue_km', 'detention_facility_km', 'public_healthcare_km', 'university_km', 'workplaces_km', 'shopping_centers_km', 'office_km', 'additional_education_km', 'preschool_km', 'big_church_km', 'church_synagogue_km', 'mosque_km', 'theater_km', 'museum_km', 'exhibition_km', 'catering_km', 'green_part_500', 'prom_part_500', 'office_count_500', 'office_sqm_500', 'trc_count_500', 'trc_sqm_500', 'cafe_count_500', 'cafe_sum_500_min_price_avg', 'cafe_sum_500_max_price_avg', 'cafe_avg_price_500', 'cafe_count_500_na_price', 'cafe_count_500_price_500', 'cafe_count_500_price_1000', 'cafe_count_500_price_1500', 'cafe_count_500_price_2500', 'cafe_count_500_price_4000', 'cafe_count_500_price_high', 'big_church_count_500', 'church_count_500', 'mosque_count_500', 'leisure_count_500', 'sport_count_500', 'market_count_500', 'green_part_1000', 'prom_part_1000', 'office_count_1000', 'office_sqm_1000', 'trc_count_1000', 'trc_sqm_1000', 'cafe_count_1000', 'cafe_sum_1000_min_price_avg', 'cafe_sum_1000_max_price_avg', 'cafe_avg_price_1000', 'cafe_count_1000_na_price', 'cafe_count_1000_price_500', 'cafe_count_1000_price_1000', 'cafe_count_1000_price_1500', 'cafe_count_1000_price_2500', 'cafe_count_1000_price_4000', 'cafe_count_1000_price_high', 'big_church_count_1000', 'church_count_1000', 'mosque_count_1000', 'leisure_count_1000', 'sport_count_1000', 'market_count_1000', 'green_part_1500', 'prom_part_1500', 'office_count_1500', 'office_sqm_1500', 'trc_count_1500', 'trc_sqm_1500', 'cafe_count_1500', 'cafe_sum_1500_min_price_avg', 'cafe_sum_1500_max_price_avg', 'cafe_avg_price_1500', 'cafe_count_1500_na_price', 'cafe_count_1500_price_500', 'cafe_count_1500_price_1000', 'cafe_count_1500_price_1500', 
'cafe_count_1500_price_2500', 'cafe_count_1500_price_4000', 'cafe_count_1500_price_high', 'big_church_count_1500', 'church_count_1500', 'mosque_count_1500', 'leisure_count_1500', 'sport_count_1500', 'market_count_1500', 'green_part_2000', 'prom_part_2000', 'office_count_2000', 'office_sqm_2000', 'trc_count_2000', 'trc_sqm_2000', 'cafe_count_2000', 'cafe_sum_2000_min_price_avg', 'cafe_sum_2000_max_price_avg', 'cafe_avg_price_2000', 'cafe_count_2000_na_price', 'cafe_count_2000_price_500', 'cafe_count_2000_price_1000', 'cafe_count_2000_price_1500', 'cafe_count_2000_price_2500', 'cafe_count_2000_price_4000', 'cafe_count_2000_price_high', 'big_church_count_2000', 'church_count_2000', 'mosque_count_2000', 'leisure_count_2000', 'sport_count_2000', 'market_count_2000', 'green_part_3000', 'prom_part_3000', 'office_count_3000', 'office_sqm_3000', 'trc_count_3000', 'trc_sqm_3000', 'cafe_count_3000', 'cafe_sum_3000_min_price_avg', 'cafe_sum_3000_max_price_avg', 'cafe_avg_price_3000', 'cafe_count_3000_na_price', 'cafe_count_3000_price_500', 'cafe_count_3000_price_1000', 'cafe_count_3000_price_1500', 'cafe_count_3000_price_2500', 'cafe_count_3000_price_4000', 'cafe_count_3000_price_high', 'big_church_count_3000', 'church_count_3000', 'mosque_count_3000', 'leisure_count_3000', 'sport_count_3000', 'market_count_3000', 'green_part_5000', 'prom_part_5000', 'office_count_5000', 'office_sqm_5000', 'trc_count_5000', 'trc_sqm_5000', 'cafe_count_5000', 'cafe_sum_5000_min_price_avg', 'cafe_sum_5000_max_price_avg', 'cafe_avg_price_5000', 'cafe_count_5000_na_price', 'cafe_count_5000_price_500', 'cafe_count_5000_price_1000', 'cafe_count_5000_price_1500', 'cafe_count_5000_price_2500', 'cafe_count_5000_price_4000', 'cafe_count_5000_price_high', 'big_church_count_5000', 'church_count_5000', 'mosque_count_5000', 'leisure_count_5000', 'sport_count_5000', 'market_count_5000']
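4. Data preprocessing
# Preparation before processing the data
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

# A plain Pipeline cannot pick columns out of a DataFrame, so this custom transformer
# selects the named columns and returns them as a NumPy array. That lets the numeric
# and categorical (text) columns go through separate pipelines.
class DataFrameSelector(BaseEstimator, TransformerMixin):
    def __init__(self, attribute_names):
        self.attribute_names = attribute_names

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        return X[self.attribute_names].values

# Numeric pipeline: select numeric columns -> fill NaNs with the median -> standardize
num_pipeline = Pipeline([
    ('selector', DataFrameSelector(numeric_feature_names)),
    ('imputer', SimpleImputer(strategy="median")),
    ('std_scaler', StandardScaler())
])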
housing_train_num = housing_train.drop(columns=text_feature_names)  # numeric columns only
housing_train_num.describe()
| | id | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | state | ... | cafe_count_5000_price_1500 | cafe_count_5000_price_2500 | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 8000.000000 | 8000.000000 | 6220.000000 | 7873.000000 | 390.000000 | 390.000000 | 348.000000 | 390.000000 | 390.000000 | 312.000000 | ... | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 | 8000.000000 |
mean | 4981.756625 | 53.796250 | 34.161415 | 8.002794 | 12.061538 | 1.892308 | 1956.942529 | 1.933333 | 7.156410 | 2.448718 | ... | 61.526625 | 30.692375 | 10.188875 | 1.665625 | 14.550750 | 29.439000 | 0.487750 | 8.170375 | 52.779125 | 6.010375 |
std | 2899.039994 | 64.166145 | 23.385401 | 5.412505 | 6.069970 | 1.435393 | 212.097490 | 0.785908 | 7.406913 | 0.733181 | ... | 119.186798 | 70.607171 | 27.197218 | 5.218378 | 27.970448 | 45.681075 | 0.608602 | 19.862469 | 44.979570 | 4.771961 |
min | 1.000000 | 5.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 2.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
25% | 2470.750000 | 38.000000 | 20.000000 | 4.000000 | 9.000000 | 1.000000 | 1965.000000 | 1.000000 | 5.000000 | 2.000000 | ... | 8.000000 | 2.000000 | 1.000000 | 0.000000 | 2.000000 | 8.000000 | 0.000000 | 0.000000 | 12.000000 | 1.000000 |
50% | 4982.500000 | 46.000000 | 30.000000 | 7.000000 | 12.000000 | 1.000000 | 1975.000000 | 2.000000 | 7.000000 | 3.000000 | ... | 24.000000 | 9.000000 | 2.000000 | 0.000000 | 7.000000 | 16.000000 | 0.000000 | 2.000000 | 47.000000 | 5.000000 |
75% | 7489.250000 | 61.000000 | 42.000000 | 11.000000 | 17.000000 | 2.000000 | 1995.000000 | 2.000000 | 9.000000 | 3.000000 | ... | 50.000000 | 20.000000 | 4.000000 | 1.000000 | 12.000000 | 28.000000 | 1.000000 | 6.000000 | 76.000000 | 10.000000 |
max | 10002.000000 | 5326.000000 | 802.000000 | 44.000000 | 40.000000 | 6.000000 | 2015.000000 | 5.000000 | 123.000000 | 4.000000 | ... | 636.000000 | 376.000000 | 147.000000 | 30.000000 | 151.000000 | 250.000000 | 2.000000 | 105.000000 | 218.000000 | 20.000000 |
8 rows × 275 columns
# housing[0].value_counts()
housing_train_num_str = num_pipeline.fit_transform(housing_train_num)  # housing_train_num: training data with the text columns removed
housing_num_str = pd.DataFrame(housing_train_num_str, columns=housing_train_num.columns)
housing_num_str.describe()
| | id | full_sq | life_sq | floor | max_floor | material | build_year | num_room | kitch_sq | state | ... | cafe_count_5000_price_1500 | cafe_count_5000_price_2500 | cafe_count_5000_price_4000 | cafe_count_5000_price_high | big_church_count_5000 | church_count_5000 | mosque_count_5000 | leisure_count_5000 | sport_count_5000 | market_count_5000 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | ... | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8.000000e+03 | 8000.000000 | 8.000000e+03 | 8.000000e+03 |
mean | 7.105427e-18 | -1.243450e-17 | -7.904788e-17 | -7.815970e-17 | -7.016610e-17 | -2.406964e-16 | -9.792167e-16 | -1.749711e-16 | 2.153833e-17 | 4.325429e-16 | ... | -1.554312e-17 | 4.085621e-17 | 4.884981e-17 | -1.509903e-17 | 3.463896e-17 | 7.993606e-18 | -6.217249e-17 | 0.000000 | 2.664535e-18 | 2.131628e-17 |
std | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | ... | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063e+00 | 1.000063 | 1.000063e+00 | 1.000063e+00 |
min | -1.718178e+00 | -7.605148e-01 | -1.606265e+00 | -1.487177e+00 | -8.967104e+00 | -1.174785e-01 | -4.453838e+01 | -5.731910e+00 | -4.289547e+00 | -1.101076e+01 | ... | -5.162524e-01 | -4.347192e-01 | -3.746527e-01 | -3.192044e-01 | -5.202512e-01 | -6.007020e-01 | -8.014771e-01 | -0.411373 | -1.173476e+00 | -1.259598e+00 |
25% | -8.662052e-01 | -2.461927e-01 | -5.913394e-01 | -7.423666e-01 | -2.241216e-03 | -1.174785e-01 | 1.772092e-02 | 1.868945e-02 | -4.667458e-03 | 1.196519e-01 | ... | -4.491267e-01 | -4.063917e-01 | -3.378820e-01 | -3.192044e-01 | -4.487427e-01 | -4.693484e-01 | -8.014771e-01 | -0.411373 | -9.066710e-01 | -1.050027e+00 |
50% | 2.564371e-04 | -1.215086e-01 | -1.563711e-01 | -1.837587e-01 | -2.241216e-03 | -1.174785e-01 | 1.772092e-02 | 1.868945e-02 | -4.667458e-03 | 1.196519e-01 | ... | -3.148752e-01 | -3.072454e-01 | -3.011112e-01 | -3.192044e-01 | -2.699714e-01 | -2.942102e-01 | -8.014771e-01 | -0.310674 | -1.284914e-01 | -2.117448e-01 |
75% | 8.649933e-01 | 1.122742e-01 | 1.819376e-01 | 5.610518e-01 | -2.241216e-03 | -1.174785e-01 | 1.772092e-02 | 1.868945e-02 | -4.667458e-03 | 1.196519e-01 | ... | -9.671663e-02 | -1.514442e-01 | -2.275697e-01 | -1.275620e-01 | -9.120016e-02 | -3.150298e-02 | 8.417359e-01 | -0.109277 | 5.162861e-01 | 8.361079e-01 |
max | 1.731800e+00 | 8.217002e+01 | 3.715424e+01 | 6.705738e+00 | 2.091577e+01 | 1.338579e+01 | 9.201230e-01 | 1.727049e+01 | 7.100191e+01 | 5.684856e+00 | ... | 4.820243e+00 | 4.890852e+00 | 5.030650e+00 | 5.430067e+00 | 4.878641e+00 | 4.828581e+00 | 2.484949e+00 | 4.875309 | 3.673472e+00 | 2.931813e+00 |
8 rows × 275 columns
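In the standardized table, std is 1.000063 rather than 1, and the means are around 1e-17. Both are expected. `StandardScaler` divides by the population standard deviation (ddof=0), while `describe()` reports the sample standard deviation (ddof=1), so the two differ by a factor of sqrt(n/(n-1)) with n = 8000. The tiny means are floating-point round-off. A short check on random data:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

n = 8000  # rows in housing_train_num
factor = np.sqrt(n / (n - 1))
print(round(factor, 6))  # 1.000063, the std that describe() shows

# Check on random data: the scaler targets population std (ddof=0)
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(n, 2))
scaled = pd.DataFrame(StandardScaler().fit_transform(X))
print(scaled.std(ddof=0).round(6).tolist())  # population std
print(scaled.std().round(6).tolist())        # sample std (ddof=1), as in describe()
```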
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.impute import SimpleImputer
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
import time
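# Categorical pipeline: select the text columns -> fill NaNs with the most frequent value -> one-hot encode
cat_pipeline = Pipeline(steps=[
    ('selector', DataFrameSelector(text_feature_names)),
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])
# Concatenate the outputs side by side: numeric columns first, then the one-hot columns
full_pipeline = FeatureUnion(transformer_list=[
    ("num_pipeline", num_pipeline),
    ("cat_pipeline", cat_pipeline)
])
# Preprocessing and model in one estimator, so GridSearchCV can tune both together
full_pipeline_with_model = Pipeline(steps=[
    ("full_pipeline", full_pipeline),
    ("model", LinearRegression()),
])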
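# Keys follow the Pipeline naming convention "<step>__<sub-step>__<param>"
param_grid = {
    "full_pipeline__num_pipeline__imputer__strategy": ["mean", "median", "most_frequent"],
    "full_pipeline__cat_pipeline__imputer__strategy": ["most_frequent", "constant"],
    # With 'error', OneHotEncoder raises on a category that occurs only in a validation fold
    # (likely here, e.g. unseen timestamp values); GridSearchCV then records a NaN score
    "full_pipeline__cat_pipeline__onehot__handle_unknown": ['error', 'ignore'],
    "model": [LinearRegression(), SVR(), RandomForestRegressor()]
}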
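The long runtime follows from the grid size. `GridSearchCV` fits every combination once per fold and then refits the best one. The count below uses `ParameterGrid`, with strings standing in for the model objects so that it runs on its own:

```python
from sklearn.model_selection import ParameterGrid

# Same shape as param_grid; model objects replaced by strings so the sketch runs standalone
grid = {
    "full_pipeline__num_pipeline__imputer__strategy": ["mean", "median", "most_frequent"],
    "full_pipeline__cat_pipeline__imputer__strategy": ["most_frequent", "constant"],
    "full_pipeline__cat_pipeline__onehot__handle_unknown": ["error", "ignore"],
    "model": ["LinearRegression", "SVR", "RandomForestRegressor"],
}
candidates = list(ParameterGrid(grid))
cv = 3
n_error = sum(p["full_pipeline__cat_pipeline__onehot__handle_unknown"] == "error"
              for p in candidates)

print(len(candidates), len(candidates) * cv)  # 36 candidates, 108 cross-validation fits
print(n_error)                                # 18 candidates use handle_unknown='error'
# ParameterGrid sorts the keys and varies the last one ("model") fastest,
# so consecutive triples of scores are LinearRegression, SVR, RandomForestRegressor
print([p["model"] for p in candidates[:3]])
```

Half of the 36 candidates use `handle_unknown='error'`. That matches the 18 `nan` entries in the `UserWarning` printed by the search.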
start = time.time()
# Create the GridSearchCV object and run the parameter search (3-fold cross-validation)
grid_search = GridSearchCV(full_pipeline_with_model, param_grid=param_grid, cv=3)
grid_search.fit(housing_train, housing_train_labels)
end = time.time()
print("Elapsed time (s): " + str(end-start))
# Print the best parameter combination and its score
print("Best parameter combination: ", grid_search.best_params_)
print("Best model score: ", grid_search.best_score_)
F:\anaconda3\envs\TF2.1\lib\site-packages\sklearn\model_selection\_search.py:972: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan
nan nan nan -0.60319146 -0.02711282 0.52201065
-0.61663247 -0.02711263 0.52200177 -0.65495276 -0.02711175 0.52994882
nan nan nan nan nan nan
nan nan nan -0.60319146 -0.02711282 0.53401227
-0.61663247 -0.02711263 0.53067201 -0.65495276 -0.02711175 0.52880227]
category=UserWarning,
Elapsed time (s): 20075.620085954666
Best parameter combination: {'full_pipeline__cat_pipeline__imputer__strategy': 'constant', 'full_pipeline__cat_pipeline__onehot__handle_unknown': 'ignore', 'full_pipeline__num_pipeline__imputer__strategy': 'mean', 'model': RandomForestRegressor()}
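Unless `scoring` is set, `GridSearchCV` scores a regressor with its `score` method, which for scikit-learn regressors is R² = 1 − Σ(y − ŷ)² / Σ(y − ȳ)². R² = 0 is what a model that always predicts the mean achieves. Negative values, such as the ≈ −0.60 to −0.65 and ≈ −0.027 in the warning above, mean the model did worse than that baseline. Assuming the ParameterGrid ordering (LinearRegression, SVR, RandomForestRegressor in each triple), those are the linear model and SVR. SVR's near-constant score is plausibly due to its defaults (C=1.0) facing an unscaled target in the millions. A small check using the prices from `housing.head()`:

```python
import numpy as np
from sklearn.metrics import r2_score

# Sale prices from the first five rows of housing.head()
y = np.array([5850000, 6000000, 5700000, 13100000, 16331452], dtype=float)

mean_pred = np.full_like(y, y.mean())        # constant model: always predict the mean
median_pred = np.full_like(y, np.median(y))  # constant model: always predict the median

r2_mean = r2_score(y, mean_pred)      # 0.0 up to round-off
r2_median = r2_score(y, median_pred)  # below 0: worse than the mean baseline
print(r2_mean, r2_median)
```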
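# Analyze the importance of each attribute (column) by pairing importances with feature names
best_pipeline = grid_search.best_estimator_
best_model = best_pipeline.named_steps['model']
feature_importances = best_model.feature_importances_
# FeatureUnion outputs the numeric columns first, so the first len(numeric_feature_names)
# importances belong to them; the rest belong to the one-hot encoded text columns
feature_importances = feature_importances[:len(numeric_feature_names)]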
print(sorted(zip(feature_importances, numeric_feature_names), reverse=True))
[(0.3558050288728259, 'full_sq'), (0.0369022366376152, 'cafe_count_2000'), (0.030728634527189217, 'office_sqm_5000'), (0.019937082781744157, 'cafe_count_3000_price_1500'), (0.015537904861632186, 'cafe_count_3000'), (0.013112393705737276, 'floor'), (0.01302595118237522, 'life_sq'), (0.011152233036002608, 'id'), (0.010880138446129202, 'cafe_count_2000_price_2500'), (0.008877243813506767, 'sport_count_3000'), (0.008849575375988353, 'sadovoe_km'), (0.00847971413204643, 'cafe_count_1000'), (0.007812304011587292, 'build_count_monolith'), (0.007124466078378038, 'theater_km'), (0.00598955666238343, 'green_part_1000'), (0.005293732058797867, 'cafe_count_5000_price_4000'), (0.004837345161800656, 'cafe_count_3000_price_2500'), (0.004817345418725336, 'kindergarten_km'), (0.00461492176396289, 'hospital_beds_raion'), (0.0045375430504188555, 'sport_count_2000'), (0.004470754833849411, 'preschool_quota'), (0.004425261691208156, 'public_healthcare_km'), (0.004284497060241596, 'big_road1_km'), (0.004248784991517519, 'zd_vokzaly_avto_km'), (0.004136713219380025, 'cafe_sum_1500_min_price_avg'), (0.0040865715928336014, 'cafe_count_1500'), (0.004079397582177327, 'additional_education_km'), (0.004052374972389635, 'big_road2_km'), (0.003988470054377696, 'ID_metro'), (0.0037952582929236, 'build_count_panel'), (0.003774838216804793, 'workplaces_km'), (0.0037429183086013025, 'swim_pool_km'), (0.0037373970684597656, 'power_transmission_line_km'), (0.0036912929593904283, 'metro_km_avto'), (0.0036524212108699876, 'green_part_2000'), (0.003637247698047196, 'public_transport_station_min_walk'), (0.0035343553043008886, 'market_shop_km'), (0.0035300284308685064, 'prom_part_3000'), (0.003493039231479483, 'green_zone_km'), (0.0034719231362119367, 'water_treatment_km'), (0.0034219894204335654, 'green_part_1500'), (0.003329615073905508, 'catering_km'), (0.003314536975526, 'public_transport_station_km'), (0.0033046821247699238, 'thermal_power_plant_km'), (0.003278634047964668, 'radiation_km'), 
(0.003228769536659676, 'church_synagogue_km'), (0.003166425810752949, 'big_market_km'), (0.003121693741607176, 'prom_part_1500'), (0.003092860614699901, 'detention_facility_km'), (0.0030660751025299326, 'cafe_count_1000_price_high'), (0.0030641894052107197, 'cafe_count_5000_price_2500'), (0.002982240570342268, 'industrial_km'), (0.0029721100255484362, 'oil_chemistry_km'), (0.0029372748333046216, 'ice_rink_km'), (0.0029037018536741265, 'hospice_morgue_km'), (0.0028979672492565438, 'basketball_km'), (0.0028716241247858025, 'preschool_km'), (0.0028520765357115613, 'shopping_centers_km'), (0.002827192165055018, 'cafe_sum_3000_min_price_avg'), (0.002767545941323787, 'fitness_km'), (0.0027459730591502386, 'cafe_sum_3000_max_price_avg'), (0.002737267343525625, 'cafe_count_5000_na_price'), (0.002666128595370512, 'nuclear_reactor_km'), (0.0026443891031494643, 'water_km'), (0.0025825863149320084, 'metro_min_avto'), (0.002551003232936204, 'big_church_km'), (0.0025421906600374853, 'ID_big_road1'), (0.0025005218540241976, 'exhibition_km'), (0.0024592455179383373, 'railroad_station_avto_km'), (0.002446947457164698, 'school_km'), (0.00243477689783016, 'trc_sqm_5000'), (0.002388639752287746, 'cafe_sum_2000_min_price_avg'), (0.002373529486919849, 'cemetery_km'), (0.0023705833591715717, 'railroad_km'), (0.0023601865395306318, 'cafe_count_1500_price_1000'), (0.0022960336723069886, 'cafe_avg_price_1500'), (0.0022944322375373204, 'green_part_3000'), (0.0022678600686821377, 'cafe_avg_price_2000'), (0.0022403779397926764, 'sport_count_5000'), (0.002222779912811367, 'ttk_km'), (0.0022130206446797107, 'university_km'), (0.002207947463485604, 'railroad_station_avto_min'), (0.0021962332587115666, 'ts_km'), (0.002169777156575615, 'metro_km_walk'), (0.002149595498746991, 'incineration_km'), (0.002123650071105262, 'cafe_avg_price_3000'), (0.002099515576433346, 'ID_railroad_station_walk'), (0.002099417314723335, 'green_part_5000'), (0.0020798017252723726, 'trc_sqm_1000'), (0.0020712663570771847, 
'bus_terminal_avto_km'), (0.002063120390980617, 'cafe_sum_2000_max_price_avg'), (0.0020253531204588756, 'cafe_sum_500_min_price_avg'), (0.002010741544731312, 'stadium_km'), (0.0019952258027357635, 'museum_km'), (0.0019808672104866024, 'green_part_500'), (0.0019596999429266713, 'trc_sqm_3000'), (0.0019418970137407457, 'trc_count_2000'), (0.0018829571414917402, 'office_count_1000'), (0.001860785317751558, 'park_km'), (0.0018573122257889987, 'ID_railroad_station_avto'), (0.0018502586755312586, 'office_km'), (0.0018502048528180166, 'sport_count_1500'), (0.0017950315340824433, 'cafe_count_3000_na_price'), (0.0017408916406401887, 'cafe_count_1500_price_4000'), (0.0017066627368712514, 'cafe_count_1000_price_1500'), (0.0017012006774767565, 'cafe_count_5000_price_1000'), (0.0016788047956672809, 'cafe_sum_5000_min_price_avg'), (0.0016591314676513353, 'office_count_5000'), (0.0016215611970843834, 'metro_min_walk'), (0.0016027121916663456, 'trc_sqm_2000'), (0.0015896148083924218, 'prom_part_1000'), (0.001546125572060984, 'mosque_km'), (0.0015365685956578443, 'office_sqm_1500'), (0.00152326409219622, 'full_all'), (0.0015204454965024228, 'office_sqm_3000'), (0.0015057001257621492, 'trc_count_3000'), (0.001494633569870594, 'prom_part_5000'), (0.0014741025308991666, 'cafe_count_3000_price_500'), (0.0014712578299238576, 'trc_sqm_1500'), (0.0014707171791260405, 'prom_part_2000'), (0.0014706187046799005, 'trc_count_1500'), (0.0014375284468118282, 'mkad_km'), (0.0014370105029316513, 'cafe_count_500'), (0.0014207990144875333, 'indust_part'), (0.0014038140502102637, 'cafe_count_1000_price_1000'), (0.0013930160501839928, 'cafe_count_5000_price_500'), (0.0013868644433262223, 'railroad_station_walk_min'), (0.0013698389206839422, 'cafe_count_2000_price_1000'), (0.0013394166555226875, 'cafe_count_1000_price_500'), (0.0013386965146843976, 'office_sqm_2000'), (0.0013326936809903197, 'build_count_1946-1970'), (0.0013128109452357497, 'cafe_count_3000_price_4000'), (0.0012984499017887556, 
'cafe_sum_5000_max_price_avg'), (0.0012800996165303555, 'cafe_count_1500_price_1500'), (0.0012761934657111557, 'build_count_after_1995'), (0.0012718569201853303, 'big_church_count_2000'), (0.00127058330811667, 'cafe_count_2000_price_500'), (0.0012638220027016128, 'ID_railroad_terminal'), (0.0012613567555310671, 'railroad_station_walk_km'), (0.0012434356062678146, 'green_zone_part'), (0.0012174455328585047, 'build_count_1971-1995'), (0.001214217771910042, 'sport_count_1000'), (0.0012140281465657175, 'cafe_avg_price_500'), (0.0012081753679400063, 'bulvar_ring_km'), (0.0012061047238333977, 'cafe_count_5000'), (0.0011839159920214273, 'healthcare_centers_raion'), (0.0011771248889729195, 'leisure_count_3000'), (0.001160775840178709, 'office_sqm_1000'), (0.0011551690564398636, 'cafe_count_1500_price_500'), (0.001132459081265372, 'cafe_avg_price_5000'), (0.001122449562282057, 'cafe_avg_price_1000'), (0.0011135964866840426, 'office_count_2000'), (0.0011111112115422372, 'cafe_sum_1500_max_price_avg'), (0.001069842749907959, 'school_quota'), (0.00106774249382398, 'cafe_count_2000_price_1500'), (0.0010602815927741205, 'cafe_sum_1000_max_price_avg'), (0.001044900091696954, 'cafe_sum_500_max_price_avg'), (0.001034667207937732, 'church_count_5000'), (0.0010301706062677523, 'cafe_count_5000_price_high'), (0.0010092727950096452, 'area_m'), (0.0009682554635303513, 'raion_build_count_with_material_info'), (0.0009587938492049883, '16_29_male'), (0.0009555688936662105, 'cafe_count_500_price_1500'), (0.0009535214961759879, 'ID_big_road2'), (0.000942206565069341, 'prom_part_500'), (0.0009311442185289545, 'cafe_count_500_price_2500'), (0.000925944901707187, 'cafe_count_3000_price_1000'), (0.000917036109269484, 'female_f'), (0.0009166333428757389, 'office_raion'), (0.0009037218668677034, 'trc_count_5000'), (0.0009017982067247552, 'build_count_brick'), (0.0009016525553024209, '16_29_female'), (0.0008981451459412267, 'additional_education_raion'), (0.0008892925799033124, 
'sport_objects_raion'), (0.0008869222028668247, 'leisure_count_500'), (0.0008649458196326145, 'market_count_2000'), (0.0008606533432220762, 'cafe_count_1500_na_price'), (0.0008482599181997117, 'work_male'), (0.0008313163900783171, 'raion_build_count_with_builddate_info'), (0.0008176531565173631, 'church_count_2000'), (0.000798035936345229, 'leisure_count_1000'), (0.0007978367110980138, 'church_count_3000'), (0.0007886108904256832, 'trc_sqm_500'), (0.0007589258852283437, 'office_count_3000'), (0.0007532881423325608, '16_29_all'), (0.0007492697593058538, 'trc_count_1000'), (0.0007407201224415953, 'big_church_count_5000'), (0.0007358415153196833, 'office_count_1500'), (0.0007343517710255492, 'shopping_centers_raion'), (0.0007197409542981459, 'cafe_count_2000_na_price'), (0.0007185895688121112, 'school_education_centers_raion'), (0.0007028017150241994, 'cafe_count_500_price_500'), (0.0007016613633910146, 'cafe_sum_1000_min_price_avg'), (0.0006943690591683538, 'market_count_5000'), (0.0006844796452933984, 'office_sqm_500'), (0.0006764053215622059, 'male_f'), (0.0006759856898131497, 'cafe_count_2000_price_4000'), (0.0006737379724652002, 'church_count_500'), (0.0006643541654161514, 'leisure_count_2000'), (0.0006594348467692361, 'cafe_count_5000_price_1500'), (0.0006535942774188417, 'cafe_count_1000_price_2500'), (0.0006450610078269494, 'cafe_count_1000_na_price'), (0.0006346299560631147, 'big_church_count_3000'), (0.0006268199748514777, 'market_count_1000'), (0.0006136616922540099, 'cafe_count_1500_price_2500'), (0.0006101020772724982, 'build_count_1921-1945'), (0.000605333476588235, 'build_count_slag'), (0.0006025808857398215, 'sport_count_500'), (0.0005913398638573828, 'leisure_count_1500'), (0.0005893053464770819, '0_13_all'), (0.0005726863992438555, 'office_count_500'), (0.0005686382393212504, 'cafe_count_500_price_1000'), (0.0005335367189856385, 'preschool_education_centers_raion'), (0.0005281848859666116, 'build_count_block'), (0.0005053507999611831, 
'church_count_1500'), (0.000503907065993229, 'young_female'), (0.0004944574646751173, 'cafe_count_500_price_4000'), (0.0004940955145651672, 'build_count_before_1920'), (0.0004563872071786454, '7_14_female'), (0.0004467097825214889, 'ID_bus_terminal'), (0.00044440004204567495, 'kremlin_km'), (0.0004245073138422882, 'market_count_3000'), (0.00041927115333189414, 'church_count_1000'), (0.00039153288360794847, 'ekder_male'), (0.0003894559280018075, 'leisure_count_5000'), (0.00035520364499142795, 'cafe_count_500_na_price'), (0.0003426960631525808, 'cafe_count_3000_price_high'), (0.0003358805974856705, 'ekder_female'), (0.0003354447960831321, 'market_count_1500'), (0.0003161403348918282, 'work_female'), (0.00031513222155404733, 'trc_count_500'), (0.0002789002146380332, 'big_church_count_1000'), (0.00026355257451541296, 'cafe_count_1000_price_4000'), (0.0002592117928358345, 'work_all'), (0.000246114632595913, 'ekder_all'), (0.0002355341717062321, 'mosque_count_5000'), (0.00023529864017649287, 'raion_popul'), (0.0002271843949564929, 'children_school'), (0.0002250069306485336, 'kitch_sq'), (0.00022431385558648625, 'children_preschool'), (0.00022412795110272134, 'max_floor'), (0.00022004001831562965, '0_13_male'), (0.0002177633191441963, '7_14_all'), (0.000211450005150454, 'school_education_centers_top_20_raion'), (0.00021129814290958885, 'build_count_wood'), (0.00020721481519426333, 'build_year'), (0.0002011649065965163, 'big_church_count_1500'), (0.00019640473122088542, '0_6_all'), (0.0001959025906314739, 'build_count_frame'), (0.00019056850712193235, 'big_church_count_500'), (0.00018503977876596978, '0_6_male'), (0.00018348735544191697, '0_17_all'), (0.0001816468582567304, '7_14_male'), (0.0001764278041528311, 'university_top_20_raion'), (0.00017137088658084054, 'cafe_count_2000_price_high'), (0.0001689115793605786, '0_6_female'), (0.00016185056709061189, 'market_count_500'), (0.00015899427406858044, 'num_room'), (0.0001502293154301887, '0_17_male'), 
(0.00014287804669518018, 'young_all'), (0.00013951399639071044, 'state'), (0.00013783866519237893, 'material'), (0.00013297476626663562, '0_17_female'), (0.00011858523670534393, 'young_male'), (0.00010142842124175284, 'culture_objects_top_25_raion'), (9.525954968205083e-05, '0_13_female'), (9.499848473560952e-05, 'mosque_count_1000'), (9.488543927847641e-05, 'cafe_count_1500_price_high'), (7.633784366627806e-05, 'mosque_count_3000'), (5.9661958682616646e-05, 'build_count_foam'), (4.7001210819952146e-05, 'build_count_mix'), (3.5218669751835894e-05, 'mosque_count_1500'), (3.5050631496159006e-05, 'mosque_count_500'), (2.257946417881287e-05, 'mosque_count_2000'), (1.2116849876895503e-05, 'cafe_count_500_price_high')]
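A ranked list like the one above is typically produced by pairing a fitted random forest's `feature_importances_` with the feature names and sorting them in descending order. The sketch below is minimal and runs on synthetic data. It assumes a plain `RandomForestRegressor` fitted on numeric columns. The column names are borrowed from the dataset only as labels, and the values are random.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(42)
n = 500
# synthetic stand-in data; only the column names come from the dataset
X = pd.DataFrame({
    "full_sq": rng.uniform(20, 150, n),
    "life_sq": rng.uniform(10, 100, n),
    "floor": rng.integers(1, 25, n).astype(float),
})
# synthetic target driven mostly by full_sq, plus noise
y = 100_000 * X["full_sq"] + rng.normal(0, 500_000, n)

forest = RandomForestRegressor(n_estimators=50, random_state=0)
forest.fit(X, y)

# pair each importance with its column name, largest first
ranked = sorted(zip(forest.feature_importances_, X.columns), reverse=True)
for importance, name in ranked:
    print(f"{name}: {importance:.4f}")
```

If the forest sits at the end of a pipeline with one-hot encoding, `feature_importances_` follows the transformed columns rather than the raw DataFrame columns. The names list then has to match that order: for the `ColumnTransformer` used in this notebook, the numeric columns come first, followed by the one-hot columns.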
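# Error on the training set
from sklearn.metrics import mean_squared_error
import numpy as np
final_model = grid_search.best_estimator_
# best_estimator_ is the full pipeline (preprocessing + regressor), already refit on the whole
# training set, so it is applied directly to the raw housing_train; a separate
# full_pipeline.fit_transform(housing_train) step is not needed
train_predictions = final_model.predict(housing_train)
train_mse = mean_squared_error(housing_train_labels, train_predictions)
train_rmse = np.sqrt(train_mse)
print(train_rmse)
# For scale: the 25% quantile of house prices (price_doc) is 4.233045e+06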
1132793.1642235247
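For reference, the number printed above is the root mean squared error over the $m$ training rows. Here $\hat{y}^{(i)}$ is the model's prediction and $y^{(i)}$ is the true `price_doc`:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{m}\sum_{i=1}^{m}\left(\hat{y}^{(i)} - y^{(i)}\right)^2}
$$

Because of the square root, the RMSE is expressed in the same units as `price_doc`. That is why it can be read directly against the price quantile noted in the comment above.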
# Error on the test set
X_test = strat_test_set.drop("price_doc", axis=1)
Y_test = strat_test_set["price_doc"].copy()
final_predictions = final_model.predict(X_test)
final_mse = mean_squared_error(Y_test, final_predictions)
final_rmse = np.sqrt(final_mse)
print(final_rmse)
2562534.086991416
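# The RMSE is only a point estimate of the generalization error; to see how precise that
# estimate is, compute a 95% confidence interval with scipy
from scipy import stats
confidence = 0.95
squared_errors = (final_predictions - Y_test) ** 2
# t-interval for the mean squared error; the square root of both ends gives an interval for the RMSE
num = np.sqrt(stats.t.interval(confidence, len(squared_errors) - 1,
                               loc=squared_errors.mean(),
                               scale=stats.sem(squared_errors)))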
print(num)
[2206930.88712145 2874476.98781104]
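The interval above treats the per-row squared errors as a sample and builds a Student-t interval for their mean. It then takes the square root of both ends to get an interval for the RMSE. The minimal sketch below reproduces `stats.t.interval` by hand. The errors are synthetic and made up purely for illustration.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
errors = rng.normal(0, 2_500_000, size=1000)  # synthetic prediction errors
squared_errors = errors ** 2
confidence = 0.95

m = len(squared_errors)
mean = squared_errors.mean()
# standard error of the mean with ddof=1, which is also stats.sem's default
se = squared_errors.std(ddof=1) / np.sqrt(m)
# two-sided critical value of the t distribution with m - 1 degrees of freedom
t_crit = stats.t.ppf((1 + confidence) / 2, df=m - 1)
manual = np.sqrt([mean - t_crit * se, mean + t_crit * se])

# the same interval via scipy, written as in the notebook
via_scipy = np.sqrt(stats.t.interval(confidence, m - 1,
                                     loc=mean, scale=stats.sem(squared_errors)))
print(manual, via_scipy)
```

With a test set of several thousand rows, the t critical value is practically the normal one (about 1.96 at 95%).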
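Once this run has finished, restart from In[21]: housing_num -> housing.describe() -> num_pipline.fit_transform(housing_num) -> housing.describe(), and check whether the missing values were filled. Then process the text features and see how many one-hot dimensions were added.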
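The sketch below runs those two checks on a toy DataFrame. It counts missing values before and after imputation, then compares the number of columns before and after one-hot encoding. The column names (`full_sq`, `build_year`, `product_type`, `sub_area`) and their values are hypothetical stand-ins; in the notebook, the real column lists are `numeric_feature_names` and `text_feature_names`.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

# hypothetical toy data with missing values in both numeric and text columns
toy = pd.DataFrame({
    "full_sq": [43, 34, np.nan, 89],
    "build_year": [np.nan, 1990, 2005, np.nan],
    "product_type": ["Investment", "OwnerOccupier", "Investment", np.nan],
    "sub_area": ["A", "B", "C", "A"],
})
num_cols = ["full_sq", "build_year"]
cat_cols = ["product_type", "sub_area"]

# 1) missing values before and after median imputation of the numeric columns
print(toy[num_cols].isna().sum().to_dict())
num_filled = SimpleImputer(strategy="median").fit_transform(toy[num_cols])
print(int(np.isnan(num_filled).sum()))

# 2) one-hot dimensions: each text column expands into one column per distinct category
cat_filled = SimpleImputer(strategy="most_frequent").fit_transform(toy[cat_cols])
encoder = OneHotEncoder(handle_unknown="ignore")
onehot = encoder.fit_transform(cat_filled)
added = onehot.shape[1] - len(cat_cols)
print(onehot.shape[1], added, encoder.categories_)
```

On the real data, columns with many distinct values (such as an area name) dominate the number of added dimensions.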
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
# Pipeline combining data preparation and model training
numeric_transformer = Pipeline(steps=[
('imputer', SimpleImputer()),
('scaler', StandardScaler())
])
categorical_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='most_frequent')),
('onehot', OneHotEncoder(handle_unknown='ignore'))
])
preprocessor = ColumnTransformer(
transformers=[
('num', numeric_transformer, numeric_feature_names),
('cat', categorical_transformer, text_feature_names)
])
pipeline = Pipeline(steps=[('preprocessor', preprocessor),
('regressor', LinearRegression())])
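# Parameter grid searched for the best combination.
# Only some combinations can actually be fitted (see the UserWarning in the output):
# 'mean'/'median' cannot impute string columns, None is not a valid SimpleImputer strategy,
# and handle_unknown='error' raises if a validation fold contains a category unseen during training.
param_grid = {
    'preprocessor__num__imputer__strategy': ['mean', 'median', 'most_frequent', None],
    'preprocessor__cat__imputer__strategy': ['mean', 'median', 'most_frequent', None],
    'preprocessor__cat__onehot__handle_unknown': ['error', 'ignore'],
    'regressor': [LinearRegression(), SVR(), RandomForestRegressor()],
}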
import time
start = time.time()
# Create the GridSearchCV object and run the search with 3-fold cross-validation
grid_search = GridSearchCV(pipeline, param_grid=param_grid, cv=3)
grid_search.fit(housing_train, housing_train_labels)
end = time.time()
print("Elapsed time (s): " + str(end - start))
# Print the best parameter combination and its score
# (the mean cross-validated R^2, the default score for regressors)
print("Best parameter combination: ", grid_search.best_params_)
print("Best model score: ", grid_search.best_score_)
F:\anaconda3\envs\TF2.1\lib\site-packages\sklearn\model_selection\_search.py:972: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
-0.60319146 -0.02711282 0.52238756 -0.61663247 -0.02711263 0.5281558
-0.65495276 -0.02711175 0.52238873 nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan]
category=UserWarning,
Elapsed time (s): 7641.132771015167
Best parameter combination: {'preprocessor__cat__imputer__strategy': 'most_frequent', 'preprocessor__cat__onehot__handle_unknown': 'ignore', 'preprocessor__num__imputer__strategy': 'median', 'regressor': RandomForestRegressor()}
Best model score: 0.5281557991343143
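The UserWarning above lists the mean cross-validated R² of every combination in the grid, 4 × 4 × 2 × 3 = 96 in total, and only 9 of them are finite. scikit-learn walks the grid with the keys sorted alphabetically and the last key varying fastest. Read in that order, the finite scores are exactly the combinations with the categorical imputer on `most_frequent` and `handle_unknown='ignore'`, paired with a numeric strategy of `mean`, `median` or `most_frequent` and each of the three regressors. The other 87 combinations could only fail.

The sketch below uses a trimmed grid that keeps only values valid for each step. It runs on a small synthetic dataset: the columns `a`, `b`, `c` and the data are hypothetical stand-ins for `numeric_feature_names`, `text_feature_names` and the housing data. The grid keys use scikit-learn's `step__substep__parameter` naming. For example, `preprocessor__num__imputer__strategy` reaches the `strategy` of the `imputer` inside the `num` transformer.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.svm import SVR

# hypothetical synthetic data: two numeric columns (one with gaps) and one text column
rng = np.random.default_rng(0)
n = 300
X = pd.DataFrame({
    "a": rng.normal(size=n),
    "b": rng.normal(size=n),
    "c": rng.choice(["x", "y", "z"], size=n),
})
X.loc[rng.choice(n, 30, replace=False), "a"] = np.nan  # inject missing values
y = 3 * X["b"] + (X["c"] == "x") * 2 + rng.normal(scale=0.1, size=n)

numeric = Pipeline([("imputer", SimpleImputer()), ("scaler", StandardScaler())])
categorical = Pipeline([("imputer", SimpleImputer(strategy="most_frequent")),
                        ("onehot", OneHotEncoder(handle_unknown="ignore"))])
pre = ColumnTransformer([("num", numeric, ["a", "b"]), ("cat", categorical, ["c"])])
pipe = Pipeline([("preprocessor", pre), ("regressor", LinearRegression())])

# only values that are valid for each step: 3 x 3 = 9 combinations instead of 96
param_grid = {
    "preprocessor__num__imputer__strategy": ["mean", "median", "most_frequent"],
    "regressor": [LinearRegression(), SVR(), RandomForestRegressor(random_state=0)],
}
search = GridSearchCV(pipe, param_grid=param_grid, cv=3)  # default regressor score: R^2
search.fit(X, y)

scores = search.cv_results_["mean_test_score"]
print(len(scores), np.isfinite(scores).all())
print(search.best_params_)
```

The categorical imputer and `handle_unknown` are fixed in the pipeline definition here instead of being searched, since only one value of each worked in the original run.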