sklearn.datasets.make_classification()
make_classification()
是 sklearn.datasets
提供的 合成分类数据集生成工具,用于 创建模拟的分类数据,适用于 机器学习模型测试、特征工程、算法研究。
1. make_classification()
作用
- 生成可控的分类数据集(指定特征数、类别数、冗余度等)。
- 测试机器学习模型(如
KNN
、SVM
、随机森林
)。 - 研究特征选择、数据分布、类别不均衡问题。
2. make_classification()
代码示例
(1) 生成二分类数据
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# 生成数据(1000 个样本,2 个特征,2 个类别)
X, y = make_classification(n_samples=1000, n_features=2, n_classes=2, n_informative=2, n_redundant=0, random_state=42)
# 可视化数据
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.title("合成二分类数据")
plt.show()
解释
n_samples=1000
:生成1000
个样本。n_features=2
:每个样本2
个特征(便于可视化)。n_classes=2
:二分类任务(0/1
)。n_informative=2
:2
个有效特征。n_redundant=0
:0
个冗余特征(完全依赖其他特征)。
(2) 生成多分类数据
X, y = make_classification(n_samples=500, n_features=3, n_classes=3, n_clusters_per_class=1, random_state=42)
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap="viridis")
plt.title("合成多分类数据")
plt.show()
解释
n_classes=3
:三分类任务。n_clusters_per_class=1
:每个类别只有1
个簇(默认是2
)。
3. make_classification()
的参数
make_classification(n_samples=100, n_features=20, n_informative=2, n_redundant=2, n_repeated=0,
n_classes=2, n_clusters_per_class=2, weights=None, flip_y=0.01, class_sep=1.0, random_state=None)
参数 | 说明 |
---|---|
n_samples | 样本数(默认 100 ) |
n_features | 总特征数(默认 20 ) |
n_informative | 有效特征数(影响类别) |
n_redundant | 冗余特征数(由有效特征线性组合生成) |
n_repeated | 重复特征数(从其他特征复制) |
n_classes | 类别数(默认 2 ,二分类) |
n_clusters_per_class | 每个类别的簇数(默认 2 ) |
weights | 类别比例(如 weights=[0.1, 0.9] ) |
flip_y | 随机噪声比例(默认 0.01 ,增加数据噪声) |
class_sep | 类别间隔(值越大,类别越容易分开) |
random_state | 随机种子,保证结果可复现 |
4. 生成类别不均衡数据
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=42)
print("类别 0 样本数:", sum(y == 0))
print("类别 1 样本数:", sum(y == 1))
输出
类别 0 样本数: 900
类别 1 样本数: 100
解释
weights=[0.9, 0.1]
生成 类别不均衡数据,90%
样本属于类别 0
。
5. make_classification()
vs. 其他数据集生成函数
方法 | 适用情况 | 作用 |
---|---|---|
make_classification() | 分类任务 | 生成 线性可分或复杂分类数据 |
make_regression() | 回归任务 | 生成回归数据 |
make_blobs() | 聚类任务 | 生成高斯分布的聚类数据 |
make_moons() | 非线性分类 | 生成新月形数据 |
make_circles() | 非线性分类 | 生成环形数据 |
示例:
from sklearn.datasets import make_moons
X, y = make_moons(n_samples=500, noise=0.1, random_state=42)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.title("新月形数据")
plt.show()
解释
make_moons()
生成非线性分类数据,适用于 SVM、神经网络测试。
6. 适用场景
- 模拟分类任务(测试分类算法)。
- 研究特征工程(观察冗余特征、有效特征的影响)。
- 生成类别不均衡数据(验证
SMOTE
、ADASYN
等方法)。
7. 结论
make_classification()
用于生成合成分类数据,支持 调整特征数量、类别比例、数据噪声,适用于 机器学习研究和算法测试。