Task06 模拟离散数据集的贝叶斯分类算法实践
一、学习内容概括
学习资料:
1.阿里云天池-AI训练营机器学习:https://tianchi.aliyun.com/specials/promotion/aicampml?invite_channel=1&accounttraceid=7df048c2ce194081b514fd2c8e9a3f00cqmm
2.sklearnAPI:https://scikit-learn.org/stable/modules/classes.html
3.pythonAPI:https://docs.python.org/3/c-api/index.html
二、具体学习内容
代码流程:
- Step1: 库函数导入
- Step2: 数据导入&分析
- Step3: 模型训练&评估
- Step4: 模型预测
1 库函数导入
import random
import numpy as np
# 使用基于类目特征的朴素贝叶斯
from sklearn.naive_bayes import CategoricalNB
from sklearn.model_selection import train_test_split
1.1 python random模块:生成伪随机数。https://docs.python.org/zh-cn/3.7/library/random.html?highlight=random#module-random
1.2 sklearn.naive_bayes.CategoricalNB:用于分类特征的朴素贝叶斯。https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.CategoricalNB.html#sklearn.naive_bayes.CategoricalNB
2 数据导入&分析
# 模拟数据
## class numpy.random.RandomState(seed = None)seed用于初始化伪随机数生成器或实例化BitGenerator的随机种子。如果seed是整数或数组,用作MT19937 BitGenerator的种子。
rng = np.random.RandomState(1) #rng:RandomState(MT19937) at 0x7F9433213CA8
# 随机生成600个100维的数据,每一维的特征都是[0, 4]之前的整数
## numpy.random.RandomState.randint(low[, high, size, dtype]),从[low,high)中返回随机整数,若high为None,则[0,low)。size:输出的形状
X = rng.randint(7, size=(600, 100)) ## X.shape(600,100)
y = np.array([1, 2, 3, 4, 5, 6] * 100) ## y.shape(600,)
data = np.c_[X, y] ## data.shape(600,101) :把y变成向量并在x最后作为最后一列
# X和y进行整体打散
## python random.shuffle(x[, random]),将序列 x随机打乱位置。
random.shuffle(data)
X = data[:,:-1]
y = data[:, -1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
numpy.random:随机抽样。https://numpy.org/doc/stable/reference/random/index.html?highlight=random#module-numpy.random
numpy.random.RandomState:传统随机生成。https://numpy.org/doc/stable/reference/random/legacy.html#legacy
numpy.random.RandomState.randint(low[, high, size, dtype]):从[low,high)返回随机整数。https://numpy.org/doc/stable/reference/random/generated/numpy.random.RandomState.randint.html#numpy.random.RandomState.randint
python random.shuffle(x[, random]):将序列 x 随机打乱位置。https://docs.python.org/zh-cn/3.7/library/random.html?highlight=random#module-random
# 查看数据
print(X,X.shape)
print(y,y.shape)
print(data,data.shape)
运行结果:
[[5 3 4 ... 6 0 0]
[1 1 5 ... 0 3 0]
[5 3 4 ... 6 0 0]
...
[4 3 6 ... 6 0 4]
[5 6 4 ... 6 4 0]
[4 2 6 ... 3 4 5]] (600, 100)
[1 2 1 2 5 5 3 1 1 1 2 3 5 4 3 2 2 4 1 6 2 6 4 2 3 2 1 4 4 1 6 3 3 4 1 5 5
2 4 4 3 6 3 1 2 6 5 3 2 6 2 2 6 3 4 1 3 3 6 5 5 2 3 5 2 1 1 1 1 2 3 1 5 1
1 2 5 6 3 4 2 5 4 4 2 6 5 2 1 6 2 3 6 3 4 1 1 2 2 1 4 6 1 2 3 2 2 2 1 5 2
3 1 4 6 4 5 6 3 2 1 5 3 5 1 1 3 3 4 4 4 5 4 5 5 3 4 3 4 6 1 3 5 2 3 2 6 6
4 3 2 1 5 3 2 4 2 6 3 4 5 3 3 6 1 6 4 4 1 2 3 2 6 5 3 1 4 2 1 4 2 4 5 4 6
2 4 5 4 2 3 6 4 1 2 4 1 6 5 4 1 6 5 5 4 1 1 4 6 6 1 3 6 5 2 4 5 2 3 2 4 6
1 5 6 5 4 5 5 3 6 5 4 4 2 6 4 6 4 6 2 6 1 1 1 1 6 2 4 1 1 1 1 3 3 3 6 1 3
6 1 3 3 3 5 1 2 6 4 4 3 2 6 4 3 4 3 2 4 4 5 2 3 6 4 2 4 2 1 4 4 6 2 2 3 5
1 4 1 3 5 1 2 5 4 2 4 3 4 1 4 5 4 2 6 1 5 6 4 2 1 2 6 2 2 4 5 4 4 3 3 6 5
4 3 6 3 5 1 2 4 2 3 2 2 1 3 5 2 1 4 2 5 2 1 1 2 2 2 2 2 1 5 4 1 6 5 4 3 2
2 2 2 2 5 5 2 6 2 5 3 3 2 2 6 4 5 1 4 5 3 2 3 1 5 5 1 2 5 1 5 6 4 2 4 1 1
2 1 4 2 4 1 2 3 5 6 6 5 2 6 3 2 1 4 6 6 1 5 5 6 2 3 5 1 2 2 2 2 4 6 2 4 1
4 3 2 6 5 3 5 2 6 1 4 4 1 4 5 5 6 6 3 5 5 3 4 6 4 4 2 4 5 1 5 1 5 3 6 5 6
5 4 6 5 6 3 2 4 4 4 6 1 3 6 1 2 3 2 4 6 3 5 6 2 2 4 4 1 5 3 5 2 1 6 1 4 6
2 4 3 2 2 1 6 2 2 3 2 6 2 2 4 3 6 3 4 5 2 3 1 2 6 1 3 6 2 1 4 4 2 6 4 2 2
3 1 3 6 5 3 3 5 6 2 5 5 6 1 2 4 1 6 6 3 3 1 5 6 2 5 3 3 1 5 6 1 2 6 1 3 3
4 2 2 6 5 2 6 3] (600,)
[[5 3 4 ... 0 0 1]
[1 1 5 ... 3 0 2]
[5 3 4 ... 0 0 1]
...
[4 3 6 ... 0 4 2]
[5 6 4 ... 4 0 6]
[4 2 6 ... 4 5 3]] (600, 101)
3 模型训练&评估
clf = CategoricalNB(alpha=1) ## alpha平滑参数,默认1.0
clf.fit(X_train, y_train)
acc = clf.score(X_test, y_test) ## score(X,y [,sample_weight])返回给定测试数据和标签上的平均准确度。
print("Test Acc : %.3f" % acc)
运行结果:
Test Acc : 0.625
4 模型预测
# 随机数据测试,分析预测结果,贝叶斯会选择概率最大的预测结果。
# 比如这里的预测结果是5,5对应的概率最大。
# 由于我们是随机数据,读者运行的时候,可能会出现不一样的结果。
x = rng.randint(7, size=(1, 100))
print("预测样本:",x)
print(clf.predict_proba(x))
print(clf.predict(x))
运行结果:
预测样本: [[3 5 6 1 5 6 2 2 0 3 0 0 2 6 1 5 6 4 1 4 1 0 4 6 2 5 4 4 1 3 2 3 5 6 5 3
3 3 4 2 4 0 1 2 4 5 0 5 3 1 0 5 4 1 6 1 5 1 3 4 2 0 0 0 2 4 1 0 3 0 2 4
4 3 0 1 1 4 3 2 1 1 4 1 4 1 5 4 1 4 6 2 1 3 5 5 1 4 3 3]]
[[6.12329417e-02 4.58848710e-03 6.70379604e-03 5.76919007e-03
9.21642304e-01 6.32806677e-05]]
[5]