sklearn.svm
模块
sklearn.svm
提供了 支持向量机(Support Vector Machine, SVM) 相关模型,适用于 分类、回归和异常检测 任务。
1. sklearn.svm
主要模型
任务 | 模型 | 适用情况 |
---|---|---|
分类 | SVC | 标准 SVM 分类器(适用于小数据集) |
分类 | NuSVC | 类似 SVC ,但用 nu 控制支持向量比例 |
分类 | LinearSVC | 线性 SVM,适用于高维数据(稀疏数据) |
回归 | SVR | 标准 SVM 回归器(适用于小数据集) |
回归 | NuSVR | 类似 SVR ,但用 nu 控制误差上界 |
回归 | LinearSVR | 线性 SVM 回归器,适用于高维数据 |
异常检测 | OneClassSVM | 用于异常检测(如欺诈检测) |
2. SVM 分类 (SVC
)
(1) 训练 SVM 分类器
from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# 训练 SVM 分类器
model = SVC(kernel="rbf", C=1.0, gamma="scale")
model.fit(X_train, y_train)
# 计算准确率
accuracy = model.score(X_test, y_test)
print("SVM 分类器准确率:", accuracy)
解释
kernel="rbf"
:默认使用 径向基函数(RBF)核,适用于非线性数据。C=1.0
:正则化参数,值越小,越允许误分类(避免过拟合)。gamma="scale"
:核函数的参数,值越大,影响越局部。
3. SVC
主要参数
SVC(kernel="rbf", C=1.0, gamma="scale", degree=3, probability=False, random_state=None)
参数 | 说明 |
---|---|
kernel | 核函数类型("linear" 、"poly" 、"rbf" 、"sigmoid" ) |
C | 正则化参数(默认 1.0 ,值大则更关注正确分类) |
gamma | 核函数参数(默认 "scale" ,值越大影响越局部) |
degree | 多项式核的阶数(仅在 kernel="poly" 时有效) |
probability | 是否计算类别概率(默认 False ,计算代价较高) |
4. 线性 SVM (LinearSVC
)
(2) 训练线性 SVM
from sklearn.svm import LinearSVC
model = LinearSVC(C=1.0, max_iter=1000)
model.fit(X_train, y_train)
accuracy = model.score(X_test, y_test)
print("线性 SVM 分类器准确率:", accuracy)
解释
- 适用于高维数据(如文本分类)。
- 比
SVC(kernel="linear")
速度更快,但不支持probability=True
。
5. SVM 回归 (SVR
)
(3) 训练 SVM 回归器
from sklearn.svm import SVR
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
# 生成回归数据
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)
# 训练 SVM 回归模型
model = SVR(kernel="rbf", C=1.0, epsilon=0.1)
model.fit(X, y)
# 预测
y_pred = model.predict(X)
# 计算 R²
r2 = model.score(X, y)
print("SVM 回归器 R²:", r2)
解释
epsilon=0.1
:允许的误差范围,值越大,模型越宽松。- 适用于小数据集的非线性回归问题。
6. SVM 异常检测 (OneClassSVM
)
(4) 训练 OneClassSVM
进行异常检测
from sklearn.svm import OneClassSVM
import numpy as np
# 生成数据
rng = np.random.RandomState(42)
X = 0.3 * rng.randn(100, 2) # 100 个正常点
X = np.concatenate([X, rng.uniform(low=-4, high=4, size=(10, 2))]) # 10 个异常点
# 训练 One-Class SVM 模型
model = OneClassSVM(kernel="rbf", gamma="scale", nu=0.1)
model.fit(X)
# 预测:1 表示正常点,-1 表示异常点
y_pred = model.predict(X)
# 统计异常点数量
n_outliers = np.sum(y_pred == -1)
print("检测到的异常点数量:", n_outliers)
解释
nu=0.1
:表示假设10%
的数据是异常点(超参数)。- 适用于无标签数据的异常检测。
7. sklearn.svm
各模型对比
任务 | 模型 | 适用情况 | 主要区别 |
---|---|---|---|
分类 | SVC | 非线性分类(小数据集) | 支持核函数(默认 RBF),计算较慢 |
分类 | NuSVC | 类似 SVC | 用 nu 控制支持向量比例 |
分类 | LinearSVC | 线性分类(高维数据) | 速度快,适用于文本分类 |
回归 | SVR | 非线性回归(小数据集) | 支持核函数(默认 RBF),计算较慢 |
回归 | NuSVR | 类似 SVR | 用 nu 控制误差上界 |
回归 | LinearSVR | 线性回归(高维数据) | 速度快,适用于文本数据 |
异常检测 | OneClassSVM | 无监督异常检测 | 适用于异常检测、欺诈检测 |
8. 适用场景
- 分类任务(如 垃圾邮件检测、文本分类)。
- 回归任务(如 能源消耗预测、股票趋势分析)。
- 异常检测任务(如 欺诈检测、入侵检测)。
9. 结论
sklearn.svm
提供了 SVM 相关的分类、回归和异常检测模型,支持 非线性映射(核函数)、高维数据处理、异常检测,适用于 小规模数据集的复杂问题。