机器学习笔记代码分析(一):线性模型

开篇

这里是sklearn与tf机器学习实用指南第一章节的示例代码分析。主要是一个线性模型,分析gdp和幸福指数的关系。

数据处理

源码和数据

首先是幸福指数的数据

oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
oecd_bli.head(2)

目标是提取出来我们想要的数据,我们新获取的表示以country为行,indicator为列的表,示例数据如下:

out:
Indicator  Air pollution  Assault rate  Consultation on rule-making  \
Country                                                               
Australia           13.0           2.1                         10.5   
Austria             27.0           3.4                          7.1   

Indicator  Dwellings without basic facilities  Educational attainment  \
Country                                                                 
Australia                                 1.1                    76.0   
Austria                                   1.0                    83.0   

Indicator  Employees working very long hours  Employment rate  Homicide rate  \
Country                                                                        
Australia                              14.02             72.0            0.8   
Austria                                 7.61             72.0            0.4   

Indicator  Household net adjusted disposable income  \
Country                                               
Australia                                   31588.0   
Austria                                     31173.0   

Indicator  Household net financial wealth         ...          \
Country                                           ...           
Australia                         47657.0         ...           
Austria                           49887.0         ...           

Indicator  Long-term unemployment rate  Personal earnings  \
Country                                                     
Australia                         1.08            50449.0   
Austria                           1.19            45199.0   

Indicator  Quality of support network  Rooms per person  Self-reported health  \
Country                                                                         
Australia                        92.0               2.3                  85.0   
Austria                          89.0               1.6                  69.0   

Indicator  Student skills  Time devoted to leisure and personal care  \
Country                                                                
Australia           512.0                                      14.41   
Austria             500.0                                      14.46   

Indicator  Voter turnout  Water quality  Years in education  
Country                                                      
Australia           93.0           91.0                19.4  
Austria             75.0           94.0                17.0  

[2 rows x 24 columns]
In [11]: oecd_bli["Life satisfaction"].head()
Out[11]: 
Country
Australia    7.3
Austria      6.9
Belgium      6.9
Brazil       7.0
Canada       7.3
Name: Life satisfaction, dtype: float64

接着处理gdp的数据

gdp_per_capita = pd.read_csv(datapath+"gdp_per_capita.csv", thousands=',', delimiter='\t',
                             encoding='latin1', na_values="n/a")
gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
gdp_per_capita.set_index("Country", inplace=True)
gdp_per_capita.head(2)

这边我就不放出输出了,大家可以直接写写看看结果是怎么样,主要就是替换了一下column的名字,让城市成为index,目的是为了融合这两个表。

下面是融合表的代码

full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)
full_country_stats.sort_values(by="GDP per capita", inplace=True)
full_country_stats

输出在ipython里面显示比较诡异,还是暂时不放了。

In [19]: full_country_stats[["GDP per capita", 'Life satisfaction']].loc["United States"]
Out[19]: 
GDP per capita       55805.204
Life satisfaction        7.200
Name: United States, dtype: float64

这里就看看出我们表的目的啦

我们合并的这个表country有36个

简单的划分一下数据

remove_indices = [0, 1, 6, 8, 33, 34, 35]
keep_indices = list(set(range(36)) - set(remove_indices))

sample_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]
missing_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[remove_indices]

下面画一下图

sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3))
plt.axis([0, 60000, 0, 10])
position_text = {
    "Hungary": (5000, 1),
    "Korea": (18000, 1.7),
    "France": (29000, 2.4),
    "Australia": (40000, 3.0),
    "United States": (52000, 3.8),
}
for country, pos_text in position_text.items():
    pos_data_x, pos_data_y = sample_data.loc[country]
    country = "U.S." if country == "United States" else country
    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,
            arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))
    plt.plot(pos_data_x, pos_data_y, "ro")
save_fig('money_happy_scatterplot')
plt.show()

代码过得差不多,放上完整的代码

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn

def prepare_country_stats(oecd_bli, gdp_per_capita):
    oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
    oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
    gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
    gdp_per_capita.set_index("Country", inplace=True)
    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                                  left_index=True, right_index=True)
    full_country_stats.sort_values(by="GDP per capita", inplace=True)
    remove_indices = [0, 1, 6, 8, 33, 34, 35]
    keep_indices = list(set(range(36)) - set(remove_indices))
    return full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]

# Load the data
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")

# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()

# Select a linear model
model = sklearn.linear_model.LinearRegression()

# Train the model
model.fit(X, y)

# Make a prediction for Cyprus
X_new = [[22587]]  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]

这边主要是处理数据比较麻烦。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值