Machine Learning in Action 1.4, Logistic Regression: Solving Multiclass Classification Problems

Main idea: convert the multiclass problem into several binary classification problems (one-vs-rest). Train one binary classifier per class, then assign each observation to the class whose classifier outputs the highest probability.

import pandas as pd
import matplotlib.pyplot as plt

# The data file has no header row, so the column names must be supplied
# manually; fields in auto-mpg.data are separated by whitespace.
columns = [
    'mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration',
    'model year', 'origin', 'car name'
]
cars = pd.read_table('./data/auto-mpg.data', delim_whitespace=True, names=columns)
print(cars.head())
    mpg  cylinders  displacement horsepower  weight  acceleration  model year  \
0  18.0          8         307.0      130.0  3504.0          12.0          70   
1  15.0          8         350.0      165.0  3693.0          11.5          70   
2  18.0          8         318.0      150.0  3436.0          11.0          70   
3  16.0          8         304.0      150.0  3433.0          12.0          70   
4  17.0          8         302.0      140.0  3449.0          10.5          70   

   origin                   car name  
0       1  chevrolet chevelle malibu  
1       1          buick skylark 320  
2       1         plymouth satellite  
3       1              amc rebel sst  
4       1                ford torino  
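A side note on the data: in the UCI auto-mpg file, a few horsepower entries are recorded as '?', so pandas reads that column with object (string) dtype rather than float. This walkthrough only uses the dummy-encoded cylinders and model year columns as features, so no cleanup is required here, but if you wanted horsepower as a feature, a minimal sketch of the coercion:

cars['horsepower'] = pd.to_numeric(cars['horsepower'], errors='coerce')  # '?' becomes NaN
cars = cars.dropna(subset=['horsepower'])  # drop the handful of rows with missing horsepower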

Use pandas.get_dummies() to generate dummy (one-hot) columns for the categorical features.

# One-hot encode the cylinder counts (cyl_3 ... cyl_8).
dummy_cylinders = pd.get_dummies(cars['cylinders'], prefix='cyl')
cars = pd.concat([cars, dummy_cylinders], axis=1)

# One-hot encode the model years (year_70 ... year_82).
dummy_years = pd.get_dummies(cars['model year'], prefix='year')
cars = pd.concat([cars, dummy_years], axis=1)

# Drop the original columns now that they are encoded.
cars = cars.drop('model year', axis=1)
cars = cars.drop('cylinders', axis=1)
print(cars.head())
    mpg  displacement horsepower  weight  acceleration  origin  \
0  18.0         307.0      130.0  3504.0          12.0       1   
1  15.0         350.0      165.0  3693.0          11.5       1   
2  18.0         318.0      150.0  3436.0          11.0       1   
3  16.0         304.0      150.0  3433.0          12.0       1   
4  17.0         302.0      140.0  3449.0          10.5       1   

                    car name  cyl_3  cyl_4  cyl_5   ...     year_73  year_74  \
0  chevrolet chevelle malibu      0      0      0   ...           0        0   
1          buick skylark 320      0      0      0   ...           0        0   
2         plymouth satellite      0      0      0   ...           0        0   
3              amc rebel sst      0      0      0   ...           0        0   
4                ford torino      0      0      0   ...           0        0   

   year_75  year_76  year_77  year_78  year_79  year_80  year_81  year_82  
0        0        0        0        0        0        0        0        0  
1        0        0        0        0        0        0        0        0  
2        0        0        0        0        0        0        0        0  
3        0        0        0        0        0        0        0        0  
4        0        0        0        0        0        0        0        0  

[5 rows x 25 columns]
import numpy as np

# Shuffle the rows, then split roughly 70% train / 30% test.
shuffled_rows = np.random.permutation(cars.index)
shuffled_cars = cars.iloc[shuffled_rows]
highest_train_row = int(cars.shape[0] * .70)
train = shuffled_cars.iloc[0:highest_train_row]
test = shuffled_cars.iloc[highest_train_row:]
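Because the split depends on this random permutation, the exact probabilities printed below will vary from run to run; they show one particular run. For a reproducible split you can seed NumPy's RNG before permuting, a minimal sketch:

np.random.seed(1)  # any fixed seed makes the shuffle, and hence the split, repeatable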
from sklearn.linear_model import LogisticRegression

# One model per class (one-vs-rest): origin takes the values 1, 2, 3.
unique_origins = cars['origin'].unique()
unique_origins.sort()

models = {}
# Use only the dummy-encoded cylinder and model-year columns as features.
features = [c for c in train.columns if c.startswith('cyl') or c.startswith('year')]

for origin in unique_origins:
    model = LogisticRegression()

    X_train = train[features]
    # Binary target: True where the row's origin matches the current class.
    y_train = train['origin'] == origin

    model.fit(X_train, y_train)
    models[origin] = model
Fitting emits the same FutureWarning once per trained model (three times here); shown once:

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.
  FutureWarning)
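As the warning suggests, passing an explicit solver silences it; for example, keeping 'liblinear', the default before scikit-learn 0.22:

model = LogisticRegression(solver='liblinear')  # explicit solver, silences the FutureWarning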
# Empty frame with one column per class; filled below with P(origin == class).
testing_probs = pd.DataFrame(columns=unique_origins)
print(testing_probs)

# The testing features are the same for every model, so select them once.
X_test = test[features]
for origin in unique_origins:
    # Column 1 of predict_proba is the positive-class probability, i.e. the
    # probability that the observation belongs to this origin.
    testing_probs[origin] = models[origin].predict_proba(X_test)[:, 1]
print(testing_probs)
Empty DataFrame
Columns: [1, 2, 3]
Index: []
            1         2         3
0    0.582297  0.126099  0.306066
1    0.817356  0.067122  0.129694
2    0.272403  0.469606  0.262321
3    0.582297  0.126099  0.306066
4    0.365031  0.282050  0.337488
5    0.321619  0.321747  0.347881
6    0.265345  0.497703  0.247373
7    0.272403  0.469606  0.262321
8    0.836376  0.072543  0.104897
9    0.582297  0.126099  0.306066
10   0.323216  0.487932  0.189434
11   0.821074  0.068240  0.123570
12   0.817356  0.067122  0.129694
13   0.327120  0.325627  0.335424
14   0.316809  0.166039  0.514284
15   0.959333  0.031030  0.026298
16   0.323235  0.412184  0.256623
17   0.582297  0.126099  0.306066
18   0.973135  0.014010  0.037249
19   0.264071  0.355003  0.384353
20   0.323216  0.487932  0.189434
21   0.265345  0.497703  0.247373
22   0.582297  0.126099  0.306066
23   0.323235  0.412184  0.256623
24   0.582297  0.126099  0.306066
25   0.854075  0.086848  0.074716
26   0.316809  0.166039  0.514284
27   0.836376  0.072543  0.104897
28   0.264071  0.355003  0.384353
29   0.582297  0.126099  0.306066
..        ...       ...       ...
90   0.351290  0.340231  0.295533
91   0.818445  0.126277  0.061284
92   0.772058  0.077050  0.148501
93   0.382735  0.385388  0.224242
94   0.959333  0.031030  0.026298
95   0.975043  0.022176  0.021483
96   0.967840  0.024734  0.025550
97   0.958345  0.040700  0.017952
98   0.382735  0.385388  0.224242
99   0.957909  0.034598  0.024356
100  0.321619  0.321747  0.347881
101  0.958345  0.040700  0.017952
102  0.272403  0.469606  0.262321
103  0.316809  0.166039  0.514284
104  0.316809  0.166039  0.514284
105  0.975043  0.022176  0.021483
106  0.582297  0.126099  0.306066
107  0.323235  0.412184  0.256623
108  0.975043  0.022176  0.021483
109  0.814029  0.029313  0.228263
110  0.316809  0.166039  0.514284
111  0.821074  0.068240  0.123570
112  0.959333  0.031030  0.026298
113  0.323235  0.412184  0.256623
114  0.817356  0.067122  0.129694
115  0.382735  0.385388  0.224242
116  0.814029  0.029313  0.228263
117  0.265345  0.497703  0.247373
118  0.323216  0.487932  0.189434
119  0.836376  0.072543  0.104897

[120 rows x 3 columns]
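The post stops at the per-class probabilities. To finish the one-vs-rest procedure, pick for each test row the origin whose classifier reported the highest probability; a minimal sketch (idxmax and the accuracy check are an addition here, not part of the original):

# Predicted class = column label with the largest probability in each row.
predicted_origins = testing_probs.idxmax(axis=1)

# testing_probs uses a fresh 0..n-1 index while test keeps the shuffled one,
# so compare the underlying arrays rather than aligning on the index.
accuracy = (predicted_origins.values == test['origin'].values).mean()
print(accuracy)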
