基于随机森林模型的红酒品质分析

一、数据获取

​ 数据集:Wine Quality Data Set UCI葡萄酒数据集https://archive.ics.uci.edu/ml/datasets/wine+quality

​ 通过网站上数据集的摘要了解数据集的基本情况吗,发现UCI葡萄酒数据集包括两份:葡萄牙北部的红色和白色葡萄酒样本

​ 该样本常用于数据分析和机器学习分类等任务

​ 选择红葡萄酒数据集进行分析

1.1数据集基本信息

Attribute Information:			#数据集中各属性的说明

For more information, read [Cortez et al., 2009].

Input variables (based on physicochemical tests):			#输入变量(特征属性),基于物理化学测试
1 - fixed acidity
2 - volatile acidity
3 - citric acid
4 - residual sugar				#残糖
5 - chlorides
6 - free sulfur dioxide
7 - total sulfur dioxide
8 - density
9 - pH							#pH值
10 - sulphates
11 - alcohol					#酒精度

Output variable (based on sensory data):					#输出变量(目标属性)
12 - quality (score between 0 and 10)		#葡萄酒的质量评分

1.2数据具体情况

在这里插入图片描述

1.3导入数据集

​ 数据有表头,数据间用;隔开

​ 利用pandas完成数据的读取和预处理

​ 利用pandas模块中的read_csv()函数,并将其参数sep的值设为’;'就可以读取数据

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
import warnings
warnings.filterwarnings('ignore') 

try:        #读取数据
    wine = pd.read_csv('winequality-red.csv', sep = ';')    #将数据存在wine中,wine为DataFrame对象
except:
    print("Cannot find the file!")

二、预处理和探索

2.1查看数据基本情况

wine.info()  #查看数据基本情况
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1599 entries, 0 to 1598
Data columns (total 12 columns):
fixed acidity           1599 non-null float64
volatile acidity        1599 non-null float64
citric acid             1599 non-null float64
residual sugar          1599 non-null float64
chlorides               1599 non-null float64
free sulfur dioxide     1599 non-null float64
total sulfur dioxide    1599 non-null float64
density                 1599 non-null float64
pH                      1599 non-null float64
sulphates               1599 non-null float64
alcohol                 1599 non-null float64
quality                 1599 non-null int64
dtypes: float64(11), int64(1)
memory usage: 150.0 KB

2.2处理数据集

​ 检查重复记录duplicated()检查Series或DataFrame对象是否有重复记录——有True,无False,返回结果用sum()方法计算总和就能获得重复的行数

wine.duplicated().sum()  #检查DataFrame是否有重复记录,并用sum()计算重复行数
Out[3]: 240

wine=wine.drop_duplicates()  #若有重复记录删除并重新赋值给wine对象

wine
Out[5]: 
      fixed acidity  volatile acidity   ...     alcohol  quality
0               7.4             0.700   ...         9.4        5
1               7.8             0.880   ...         9.8        5
2               7.8             0.760   ...         9.8        5
3              11.2             0.280   ...         9.8        6
5               7.4             0.660   ...         9.4        5
6               7.9             0.600   ...         9.4        5
7               7.3             0.650   ...        10.0        7
8               7.8             0.580   ...         9.5        7
9               7.5             0.500   ...        10.5        5
10              6.7             0.580   ...         9.2        5
12              5.6             0.615   ...         9.9        5
13              7.8             0.610   ...         9.1        5
14              8.9             0.620   ...         9.2        5
15              8.9             0.620   ...         9.2        5
16              8.5             0.280   ...        10.5        7
17              8.1             0.560   ...         9.3        5
18              7.4             0.590   ...         9.0        4
19              7.9             0.320   ...         9.2        6
20              8.9             0.220   ...         9.4        6
21              7.6             0.390   ...         9.7        5
22              7.9             0.430   ...         9.5        5
23              8.5             0.490   ...         9.4        5
24              6.9             0.400   ...         9.7        6
25              6.3             0.390   ...         9.3        5
26              7.6             0.410   ...         9.5        5
28              7.1             0.710   ...         9.4        5
29              7.8             0.645   ...         9.8        6
30              6.7             0.675   ...        10.1        5
31              6.9             0.685   ...        10.6        6
32              8.3             0.655   ...         9.8        5
            ...               ...   ...         ...      ...
1566            6.7             0.160   ...        11.2        6
1568            7.0             0.560   ...         9.2        5
1569            6.2             0.510   ...        11.5        6
1570            6.4             0.360   ...        12.4        6
1571            6.4             0.380   ...        11.1        6
1572            7.3             0.690   ...         9.5        5
1573            6.0             0.580   ...        12.5        6
1574            5.6             0.310   ...        10.5        6
1575            7.5             0.520   ...        11.8        6
1576            8.0             0.300   ...        10.8        6
1577            6.2             0.700   ...        11.9        6
1578            6.8             0.670   ...        11.3        6
1579            6.2             0.560   ...        11.3        5
1580            7.4             0.350   ...        11.9        6
1582            6.1             0.715   ...        11.9        5
1583            6.2             0.460   ...         9.8        5
1584            6.7             0.320   ...        11.6        7
1585            7.2             0.390   ...        11.5        6
1586            7.5             0.310   ...        11.4        6
1587            5.8             0.610   ...        10.9        6
1588            7.2             0.660   ...        12.8        6
1589            6.6             0.725   ...         9.2        5
1590            6.3             0.550   ...        11.6        6
1591            5.4             0.740   ...        11.6        6
1592            6.3             0.510   ...        11.0        6
1593            6.8             0.620   ...         9.5        6
1594            6.2             0.600   ...        10.5        5
1595            5.9             0.550   ...        11.2        6
1597            5.9             0.645   ...        10.2        5
1598            6.0             0.310   ...        11.0        6

[1359 rows x 12 columns]

​ 简单查看目标属性quality,并查看quality属性每一类的分布情况,发现符合正态分布

wine.describe()  #查看数据基本信息
Out[6]: 
       fixed acidity  volatile acidity     ...           alcohol      quality
count    1359.000000       1359.000000     ...       1359.000000  1359.000000
mean        8.310596          0.529478     ...         10.432315     5.623252
std         1.736990          0.183031     ...          1.082065     0.823578
min         4.600000          0.120000     ...          8.400000     3.000000
25%         7.100000          0.390000     ...          9.500000     5.000000
50%         7.900000          0.520000     ...         10.200000     6.000000
75%         9.200000          0.640000     ...         11.100000     6.000000
max        15.900000          1.580000     ...         14.900000     8.000000

[8 rows x 12 columns]

wine.quality.value_counts()  #查看quality属性具体每一类有多少个值
Out[8]: 
5    577
6    535
7    167
4     53
8     17
3     10
Name: quality, dtype: int64

2.3探索特征属性和目标属性的相关性

​ 通过相关性发现与volatile acidityalcohol相关性比较大,前者为负相关,后者为正相关

​ 再通过绘图查看每个quality值对应的volatile acidityalcohol属性的均值的分布情况,直观的查看他们之间的相关性——使用seaborn模块的barplot()函数来处理

wine.corr().quality  #wine与属性之间的相关性
Out[11]: 
fixed acidity           0.119024
volatile acidity       -0.395214
citric acid             0.228057
residual sugar          0.013640
chlorides              -0.130988
free sulfur dioxide    -0.050463
total sulfur dioxide   -0.177855
density                -0.184252
pH                     -0.055245
sulphates               0.248835
alcohol                 0.480343
quality                 1.000000
Name: quality, dtype: float64
sns.barplot(x='quality',y='volatile acidity',data=wine)  #通过绘图查看均值的分布情况,了解相关性

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-deS3JuHd-1680958755074)(D:\typora\photo\volatile acidity.png)]

​ 看到quality为8对应volatile acidity所有值的均值接近0.6,综合全部发现volatile acidity值越高,quality值越低

sns.barplot(x='quality',y='alcohol',data=wine)  #通过绘图查看均值的分布情况,了解相关性

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JbQFA3RY-1680958755074)(D:\typora\photo\alcohol.png)]

​ 看到quality为8对应alcohol所有值的均值接近12,综合全部发现alcohol值越高,quality值越高

​ 属性与记录较少,不做数据归约

2.4清洗数据

​ 若有一批新的红葡萄酒数据,通过物理化学测试得知fixed acidity、volatile acid、citric acid、residual sugar、chlorides、free sulfur dioxide、total sulfur dioxide、density 、pH、sulphates、alcohol去求quality

​ 从数据分析得知quality值主要为6个3~8,类别较多,分类难度大,将多分类问题改为二分类问题

​ 可将quality值[3,8]划分为两部分:[3,6]为质量一般,[7,8]质量很好,进行二值化处理,简化问题,判断新酒是属于质量一般/很好

​ 以下程序将quality值进行3分类为low, medium, high

print(wine.info())
print(wine.describe())
wine = wine.drop_duplicates()

wine['quality'].value_counts().plot(kind = 'pie', autopct = '%.2f')
plt.show()

print(wine.corr().quality)

plt.subplot(121)
sns.barplot(x = 'quality', y = 'volatile acidity', data = wine)
plt.subplot(122)
sns.barplot(x = 'quality', y = 'alcohol', data = wine)
plt.show()

from sklearn.preprocessing import LabelEncoder
bins = (2, 4, 6, 8)  #设置待划分数据的bins,bin划分数据的方式为构成左开右闭区间,所以划分结果为(2,4](4,6](6,8]
group_names  = ['low', 'medium', 'high']  #定义bins划分后的组名
wine['quality_lb'] = pd.cut(wine['quality'], bins = bins, labels = group_names)  #使用pandas模块中的cut()函数将数据分箱

​ 属性中多出quality_lb表示最后质量情况

wine
Out[15]: 
      fixed acidity  volatile acidity     ...      quality  quality_lb
0               7.4             0.700     ...            5      medium
1               7.8             0.880     ...            5      medium
2               7.8             0.760     ...            5      medium
3              11.2             0.280     ...            6      medium
5               7.4             0.660     ...            5      medium
6               7.9             0.600     ...            5      medium
7               7.3             0.650     ...            7        high
8               7.8             0.580     ...            7        high
9               7.5             0.500     ...            5      medium
10              6.7             0.580     ...            5      medium
12              5.6             0.615     ...            5      medium
13              7.8             0.610     ...            5      medium
14              8.9             0.620     ...            5      medium
15              8.9             0.620     ...            5      medium
16              8.5             0.280     ...            7        high
17              8.1             0.560     ...            5      medium
18              7.4             0.590     ...            4         low
19              7.9             0.320     ...            6      medium
20              8.9             0.220     ...            6      medium
21              7.6             0.390     ...            5      medium
22              7.9             0.430     ...            5      medium
23              8.5             0.490     ...            5      medium
24              6.9             0.400     ...            6      medium
25              6.3             0.390     ...            5      medium
26              7.6             0.410     ...            5      medium
28              7.1             0.710     ...            5      medium
29              7.8             0.645     ...            6      medium
30              6.7             0.675     ...            5      medium
31              6.9             0.685     ...            6      medium
32              8.3             0.655     ...            5      medium
            ...               ...     ...          ...         ...
1566            6.7             0.160     ...            6      medium
1568            7.0             0.560     ...            5      medium
1569            6.2             0.510     ...            6      medium
1570            6.4             0.360     ...            6      medium
1571            6.4             0.380     ...            6      medium
1572            7.3             0.690     ...            5      medium
1573            6.0             0.580     ...            6      medium
1574            5.6             0.310     ...            6      medium
1575            7.5             0.520     ...            6      medium
1576            8.0             0.300     ...            6      medium
1577            6.2             0.700     ...            6      medium
1578            6.8             0.670     ...            6      medium
1579            6.2             0.560     ...            5      medium
1580            7.4             0.350     ...            6      medium
1582            6.1             0.715     ...            5      medium
1583            6.2             0.460     ...            5      medium
1584            6.7             0.320     ...            7        high
1585            7.2             0.390     ...            6      medium
1586            7.5             0.310     ...            6      medium
1587            5.8             0.610     ...            6      medium
1588            7.2             0.660     ...            6      medium
1589            6.6             0.725     ...            5      medium
1590            6.3             0.550     ...            6      medium
1591            5.4             0.740     ...            6      medium
1592            6.3             0.510     ...            6      medium
1593            6.8             0.620     ...            6      medium
1594            6.2             0.600     ...            5      medium
1595            5.9             0.550     ...            6      medium
1597            5.9             0.645     ...            5      medium
1598            6.0             0.310     ...            6      medium

[1359 rows x 13 columns]

​ 字符串不方便计算,使用preprocessing模块的LabelEncoder()函数分配标签

lb_quality = LabelEncoder()  #为quality_lb属性分配标签0,1,2对应low,medium,high   
wine['label'] = lb_quality.fit_transform(wine['quality_lb'])   #label属性为具体标签

​ 属性中出现label表示具体质量登记标签

wine
Out[19]: 
      fixed acidity  volatile acidity  ...    quality_lb  label
0               7.4             0.700  ...        medium      2
1               7.8             0.880  ...        medium      2
2               7.8             0.760  ...        medium      2
3              11.2             0.280  ...        medium      2
5               7.4             0.660  ...        medium      2
6               7.9             0.600  ...        medium      2
7               7.3             0.650  ...          high      0
8               7.8             0.580  ...          high      0
9               7.5             0.500  ...        medium      2
10              6.7             0.580  ...        medium      2
12              5.6             0.615  ...        medium      2
13              7.8             0.610  ...        medium      2
14              8.9             0.620  ...        medium      2
15              8.9             0.620  ...        medium      2
16              8.5             0.280  ...          high      0
17              8.1             0.560  ...        medium      2
18              7.4             0.590  ...           low      1
19              7.9             0.320  ...        medium      2
20              8.9             0.220  ...        medium      2
21              7.6             0.390  ...        medium      2
22              7.9             0.430  ...        medium      2
23              8.5             0.490  ...        medium      2
24              6.9             0.400  ...        medium      2
25              6.3             0.390  ...        medium      2
26              7.6             0.410  ...        medium      2
28              7.1             0.710  ...        medium      2
29              7.8             0.645  ...        medium      2
30              6.7             0.675  ...        medium      2
31              6.9             0.685  ...        medium      2
32              8.3             0.655  ...        medium      2
            ...               ...  ...           ...    ...
1566            6.7             0.160  ...        medium      2
1568            7.0             0.560  ...        medium      2
1569            6.2             0.510  ...        medium      2
1570            6.4             0.360  ...        medium      2
1571            6.4             0.380  ...        medium      2
1572            7.3             0.690  ...        medium      2
1573            6.0             0.580  ...        medium      2
1574            5.6             0.310  ...        medium      2
1575            7.5             0.520  ...        medium      2
1576            8.0             0.300  ...        medium      2
1577            6.2             0.700  ...        medium      2
1578            6.8             0.670  ...        medium      2
1579            6.2             0.560  ...        medium      2
1580            7.4             0.350  ...        medium      2
1582            6.1             0.715  ...        medium      2
1583            6.2             0.460  ...        medium      2
1584            6.7             0.320  ...          high      0
1585            7.2             0.390  ...        medium      2
1586            7.5             0.310  ...        medium      2
1587            5.8             0.610  ...        medium      2
1588            7.2             0.660  ...        medium      2
1589            6.6             0.725  ...        medium      2
1590            6.3             0.550  ...        medium      2
1591            5.4             0.740  ...        medium      2
1592            6.3             0.510  ...        medium      2
1593            6.8             0.620  ...        medium      2
1594            6.2             0.600  ...        medium      2
1595            5.9             0.550  ...        medium      2
1597            5.9             0.645  ...        medium      2
1598            6.0             0.310  ...        medium      2

[1359 rows x 14 columns]

​ 用value_counts()方法再统计新类别的分布

wine.label.value_counts()
Out[20]: 
2    1112
0     184
1      63
Name: label, dtype: int64

​ 对数据进行处理

wine_copy = wine.copy()
wine.drop(['quality', 'quality_lb'], axis = 1, inplace = True)  #对wine数据属性进行简化,留下label属性

​ 通过数据选择的方式将特征属性和目标属性分开存入X,y

X = wine.iloc[:,:-1]  #存储特征属性
y = wine.label  #存储目标属性
X
Out[22]: 
      fixed acidity  volatile acidity   ...     sulphates  alcohol
0               7.4             0.700   ...          0.56      9.4
1               7.8             0.880   ...          0.68      9.8
2               7.8             0.760   ...          0.65      9.8
3              11.2             0.280   ...          0.58      9.8
5               7.4             0.660   ...          0.56      9.4
6               7.9             0.600   ...          0.46      9.4
7               7.3             0.650   ...          0.47     10.0
8               7.8             0.580   ...          0.57      9.5
9               7.5             0.500   ...          0.80     10.5
10              6.7             0.580   ...          0.54      9.2
12              5.6             0.615   ...          0.52      9.9
13              7.8             0.610   ...          1.56      9.1
14              8.9             0.620   ...          0.88      9.2
15              8.9             0.620   ...          0.93      9.2
16              8.5             0.280   ...          0.75     10.5
17              8.1             0.560   ...          1.28      9.3
18              7.4             0.590   ...          0.50      9.0
19              7.9             0.320   ...          1.08      9.2
20              8.9             0.220   ...          0.53      9.4
21              7.6             0.390   ...          0.65      9.7
22              7.9             0.430   ...          0.91      9.5
23              8.5             0.490   ...          0.53      9.4
24              6.9             0.400   ...          0.63      9.7
25              6.3             0.390   ...          0.56      9.3
26              7.6             0.410   ...          0.59      9.5
28              7.1             0.710   ...          0.55      9.4
29              7.8             0.645   ...          0.59      9.8
30              6.7             0.675   ...          0.54     10.1
31              6.9             0.685   ...          0.57     10.6
32              8.3             0.655   ...          0.66      9.8
            ...               ...   ...           ...      ...
1566            6.7             0.160   ...          0.71     11.2
1568            7.0             0.560   ...          0.59      9.2
1569            6.2             0.510   ...          0.57     11.5
1570            6.4             0.360   ...          0.93     12.4
1571            6.4             0.380   ...          0.65     11.1
1572            7.3             0.690   ...          0.51      9.5
1573            6.0             0.580   ...          0.67     12.5
1574            5.6             0.310   ...          0.48     10.5
1575            7.5             0.520   ...          0.64     11.8
1576            8.0             0.300   ...          0.78     10.8
1577            6.2             0.700   ...          0.60     11.9
1578            6.8             0.670   ...          0.67     11.3
1579            6.2             0.560   ...          0.60     11.3
1580            7.4             0.350   ...          0.60     11.9
1582            6.1             0.715   ...          0.50     11.9
1583            6.2             0.460   ...          0.62      9.8
1584            6.7             0.320   ...          0.80     11.6
1585            7.2             0.390   ...          0.84     11.5
1586            7.5             0.310   ...          0.85     11.4
1587            5.8             0.610   ...          0.66     10.9
1588            7.2             0.660   ...          0.78     12.8
1589            6.6             0.725   ...          0.54      9.2
1590            6.3             0.550   ...          0.82     11.6
1591            5.4             0.740   ...          0.56     11.6
1592            6.3             0.510   ...          0.75     11.0
1593            6.8             0.620   ...          0.82      9.5
1594            6.2             0.600   ...          0.58     10.5
1595            5.9             0.550   ...          0.76     11.2
1597            5.9             0.645   ...          0.71     10.2
1598            6.0             0.310   ...          0.66     11.0

[1359 rows x 11 columns]

y
Out[23]: 
0       2
1       2
2       2
3       2
5       2
6       2
7       0
8       0
9       2
10      2
12      2
13      2
14      2
15      2
16      0
17      2
18      1
19      2
20      2
21      2
22      2
23      2
24      2
25      2
26      2
28      2
29      2
30      2
31      2
32      2
       ..
1566    2
1568    2
1569    2
1570    2
1571    2
1572    2
1573    2
1574    2
1575    2
1576    2
1577    2
1578    2
1579    2
1580    2
1582    2
1583    2
1584    0
1585    2
1586    2
1587    2
1588    2
1589    2
1590    2
1591    2
1592    2
1593    2
1594    2
1595    2
1597    2
1598    2
Name: label, Length: 1359, dtype: int64

2.5选取训练和测试数据

​ 将数据划分为数据集和训练集,使用train_test_split()函数,该函数可随机地从样本中按比例选取训练数据和测试数据,test_size参数用来设置测试集的比例

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)  #设置测试集比例为20%

2.6标准化处理

​ 数据规范化处理,对特征属性的训练集和测试集用scale()函数进行了标准化处理

from sklearn.preprocessing import scale     
X_train = scale(X_train)
X_test = scale(X_test)

三、机器学习建模

3.1机器学习模型选择

​ 随机森林(Random Forest)是一中机器学习模型,是一种集成学习(Ensemble Learning)算法

​ 集成学习即构建并结合多个学习器来完成学习任务

​ 随机森林模型属于并行式集成学习代表Bagging类型

​ 具体做法为:对原始数据集进行多次随机采样,得到多个不同的采样集,然后基于每个采样集训练一个决策树基学习器,再将这些基学习器进行结合,最终通过投票或取均值等方式使得模型获得较高的精准度和泛化性能

3.2训练模型

​ 利用RandomForestClassifier()函数构建一个分类器,n_estimators参数是指在利用最大投票数或均值来预测前想要建立决策时的子树的数量(因为是基学习器),通常较多的子树,可以让模型有更好的性能

from sklearn.metrics import confusion_matrix

rfc = RandomForestClassifier(n_estimators = 200)  #构建分类器
rfc.fit(X_train, y_train)  #基于训练集进行学习
y_pred = rfc.predict(X_test)  #利用predict()方法,基于测试集的X部分数据(X_test)进行预测
print(confusion_matrix(y_test, y_pred))  #将预测效果与实际的y值比较,用常规的混淆矩阵来观察

3.3预测结果判断

​ 混淆矩阵是一种算法性能的可视化呈现,每一列代表预测值,每一行代表实际的类别

#分类结果的混淆矩阵
[[ 16   0  19]		#16为类别0(high)判断正确的个数,19为本来是类别0但被误判为类别2,类别0的总个数为16+19
 [  0   1  10]
 [  8   1 217]]

​ 对角线上的个数为正确判断出类别的数据记录条数,其他位置为类别误判的条数,对角线上的值占总数越大表示分类效果越好

四、调参

​ 用GridSearchCV去调参需要人工选择的参数成为超参数(随机森林当中决策树的个数即前面的n_estinmators参数对应的值)

GridSearchCV函数实则为暴力搜索,将参数输入就可以给出最优化的结果和参数,适用于小数据集

grid_rfc = GridSearchCV(rfc, param_rfc, iid = False, cv = 5)  #调参
grid_rfc.fit(X_train, y_train)
best_param_rfc = grid_rfc.best_params_
print(best_param_rfc)  #保存取得最佳结果的参数的组合
#基于最佳参数组合重新训练模型,预测结果
rfc = RandomForestClassifier(n_estimators = best_param_rfc['n_estimators'], criterion = best_param_rfc['criterion'], random_state=0)
rfc.fit(X_train, y_train)
y_pred = rfc.predict(X_test)
print(confusion_matrix(y_test, y_pred))
{'criterion': 'gini', 'n_estimators': 60}
[[ 16   0  30]
 [  0   0  13]
 [  6   0 207]]

五、整体代码

# -*- coding: utf-8 -*-
"""
winequality-red data mining

@author: Dazhuang
"""
# url: https://archive.ics.uci.edu/ml/datasets/Wine+Quality
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
import warnings
warnings.filterwarnings('ignore') 

try:        #读取数据
    wine = pd.read_csv('winequality-red.csv', sep = ';')    #将数据存在wine中,wine为DataFrame对象
except:
    print("Cannot find the file!")

print(wine.info())
print(wine.describe())
wine = wine.drop_duplicates()

wine['quality'].value_counts().plot(kind = 'pie', autopct = '%.2f')
plt.show()

print(wine.corr().quality)

plt.subplot(121)
sns.barplot(x = 'quality', y = 'volatile acidity', data = wine)
plt.subplot(122)
sns.barplot(x = 'quality', y = 'alcohol', data = wine)
plt.show()

from sklearn.preprocessing import LabelEncoder
bins = (2, 4, 6, 8)  #设置待划分数据的bins,bin划分数据的方式为构成左开右闭区间,所以划分结果为(2,4](4,6](6,8]
group_names  = ['low', 'medium', 'high']  #定义bins划分后的组名
wine['quality_lb'] = pd.cut(wine['quality'], bins = bins, labels = group_names)  #使用pandas模块中的cut()函数将数据分箱

lb_quality = LabelEncoder()  #字符串不方便计算,为quality_lb属性分配标签0,1,2对应low,medium,high   
wine['label'] = lb_quality.fit_transform(wine['quality_lb'])   #label属性为具体标签

print(wine.label.value_counts())

wine_copy = wine.copy()
wine.drop(['quality', 'quality_lb'], axis = 1, inplace = True)  #对wine数据属性进行简化,留下label属性

X = wine.iloc[:,:-1]  #存储特征属性
y = wine.label  #存储目标属性

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)  #设置测试集比例为20%

from sklearn.preprocessing import scale     
X_train = scale(X_train)
X_test = scale(X_test)

from sklearn.metrics import confusion_matrix

rfc = RandomForestClassifier(n_estimators = 200)  #构建分类器
rfc.fit(X_train, y_train)  #基于训练集进行学习
y_pred = rfc.predict(X_test)  #利用predict()方法,基于测试集的X部分数据(X_test)进行预测
print(confusion_matrix(y_test, y_pred))  #将预测效果与实际的y值比较,用常规的混淆矩阵来观察

param_rfc = {  #选择要调参的参数
            "n_estimators": [10,20,30,40,50,60,70,80,90,100,150,200],
            "criterion": ["gini", "entropy"]
            }
grid_rfc = GridSearchCV(rfc, param_rfc, iid = False, cv = 5)  #调参
grid_rfc.fit(X_train, y_train)
best_param_rfc = grid_rfc.best_params_
print(best_param_rfc)  #保存取得最佳结果的参数的组合
#基于最佳参数组合重新训练模型,预测结果
rfc = RandomForestClassifier(n_estimators = best_param_rfc['n_estimators'], criterion = best_param_rfc['criterion'], random_state=0)
rfc.fit(X_train, y_train)
y_pred = rfc.predict(X_test)
print(confusion_matrix(y_test, y_pred))

六、参考材料

[1]用Python玩转数据

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值