前言
上一章,我们讲了SVM手写数字识别 - SKLearn实现SVM - MNIST数据集,其中使用的数据集是sklearn中自带的mnist数据集。现在,我们要在上一个代码的基础上,将数据集更换为medical-mnist,并且进行分类与预测。
数据集准备
首先,数据集长这个样子(网上随便找的),四个文件夹分别代表四个类别,每个文件夹内都是1000张左右的.jpeg后缀的图片文件
打开pycharm,把数据集放进去。顺便再创建一个test文件夹,里面放一张图,供一会儿测试使用
开始敲代码
先把要导的包放在最前面
# sklearn
import joblib
import sklearn.model_selection as sk_model_selection
import sklearn.preprocessing as sk_preprocessing
from sklearn.model_selection import GridSearchCV
from sklearn import svm
# 图像处理opencv
import os
import cv2
一、加载数据集
# 数据集根目录
dataset_root = 'data/medical-mnist'
# 类别名称列表
class_names = ['AbdomenCT', 'BreastMRI', 'ChestCT', 'CXR']
# 存储图像数据和标签的列表
images = []
labels = []
# 遍历每个类别文件夹
for class_name in class_names:
# 类别文件夹路径
class_dir = os.path.join(dataset_root, class_name)
# 获取类别的标签
label = class_names.index(class_name)
# 遍历当前类别文件夹下的图像文件
for filename in os.listdir(class_dir):
if filename.endswith('.jpeg'):
# 图像文件路径
image_path = os.path.join(class_dir, filename)
# 使用OpenCV加载图像
image = cv2.imread(image_path)
# 将图像转换为灰度图或其他需要的格式
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 将图像数据添加到列表
images.append(gray_image.reshape(-1))
# 将对应的标签添加到列表
labels.append(label)
这时我们print一下,看看效果咋样
# 打印数据集大小和标签示例
print('数据集大小:', len(images))
print('标签示例:', labels[999:1002])
二、处理数据集
- X和y:输入的数据集和标签。
- train_size=0.8:指定训练集的比例,这里设置为80%。
- random_state=20:指定随机种子,确保划分的结果可以复现。
scaler = sk_preprocessing.StandardScaler().fit(X_train)
:使用StandardScaler
类对训练集数据进行拟合,计算出均值和标准差。X_train_scaled = scaler.transform(X_train)
:对训练集数据进行标准化处理,将数据按照均值和标准差进行缩放转换。
# 加载新的数据集
X = images
y = labels
# 数据集划分
X_train, X_test, y_train, y_test = sk_model_selection.train_test_split(X, y, train_size=0.8, random_state=20)
# 数据预处理
scaler = sk_preprocessing.StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
三、模型训练与评估
-
创建SVM模型:使用SVM(支持向量机)算法创建一个分类模型。参数C=1.0表示正则化参数为1.0(给定初始值可以帮助缩小搜索空间并加速搜索过程),kernel='rbf'表示使用径向基函数作为核函数,gamma='auto'表示使用默认的gamma值。
-
拟合模型:使用训练集(X_train_scaled和y_train)对模型进行拟合,通过学习训练数据集的特征和标签之间的关系,训练模型的参数。
-
为模型打分:使用测试集(X_test_scaled和y_test)对训练好的模型进行评估,计算模型在测试数据上的准确率。
-
网格搜索优化模型:使用网格搜索(GridSearchCV)来搜索SVM模型的最佳参数组合。通过指定不同的参数值,如kernel和C的取值范围,GridSearchCV会自动进行交叉验证来确定最佳参数组合。在给定的参数范围内,网格搜索会尝试不同的参数组合并计算每个组合的得分,最后找到最佳得分对应的最佳参数组合。
-
输出最优参数和最优得分:打印出网格搜索找到的最佳参数组合和对应的最佳得分。
-
交叉验证:使用交叉验证(cross_val_score)对最优模型进行评估。将数据集分成K个子集,依次使用其中一个子集作为验证集,其余子集作为训练集,计算模型在验证集上的准确率。通过交叉验证可以更准确地评估模型的性能。
# 创建SVM模型
model = svm.SVC(C=1.0, kernel='rbf', gamma='auto')
# 拟合模型
model.fit(X_train_scaled, y_train)
# 为模型打分
accuracy = model.score(X_test_scaled, y_test)
print('SVM模型评价:', accuracy)
# 网格搜索优化模型
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
grid_model = GridSearchCV(model, parameters)
grid_model.fit(X_train_scaled, y_train)
print("最优参数: ", grid_model.best_params_)
print("最优得分: ", grid_model.best_score_)
# 交叉验证
scores = sk_model_selection.cross_val_score(grid_model, X, y, cv=5)
print("交叉验证得分: ", scores)
print("平均准确率: ", scores.mean())
四、模型保存与加载
- 将
grid_model
(经过网格搜索优化后的模型)保存到名为'svm_medical_model.pkl'
的文件中。joblib.dump()
函数用于将对象序列化并保存到文件中,以便后续使用。 - 从文件
'svm_medical_model.pkl'
中加载模型,并将其存储在loaded_model
变量中。joblib.load()
函数用于从文件中加载先前保存的模型。
# 模型保存
joblib.dump(grid_model, 'svm_medical_model.pkl')
# 模型加载
loaded_model = joblib.load('svm_medical_model.pkl')
五、模型预测
image_path = "data/test/test.jpeg"
: 定义了图像文件的路径- 使用
scaler
对象的transform()
方法对test
进行特征缩放,以使其具有与训练数据相似的数据分布。然后,使用reshape()
函数将test
转换为形状为(1, -1)
的二维数组,以匹配模型期望的输入形状。 - 使用加载的模型
loaded_model
对预处理后的图像test
进行预测。
# 图像文件路径
image_path = "data/test/test.jpeg"
# 使用OpenCV加载图像
image = cv2.imread(image_path)
# 将图像转换为灰度图或其他需要的格式
test = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 对图像进行与训练数据相同的预处理
test = scaler.transform(gray_image.reshape(1, -1))
# 自己选择的图像 预测结果
predictions = loaded_model.predict(test)
print(predictions)
完整代码
# sklearn
import joblib
import sklearn.model_selection as sk_model_selection
import sklearn.preprocessing as sk_preprocessing
from sklearn.model_selection import GridSearchCV
from sklearn import svm
# 图像处理opencv
import os
import cv2
# 数据集根目录
dataset_root = 'data/medical-mnist'
# 类别名称列表
class_names = ['AbdomenCT', 'BreastMRI', 'ChestCT', 'CXR']
# 存储图像数据和标签的列表
images = []
labels = []
# 遍历每个类别文件夹
for class_name in class_names:
# 类别文件夹路径
class_dir = os.path.join(dataset_root, class_name)
print(class_name)
print(class_names)
# 获取类别的标签
label = class_names.index(class_name)
print(label)
# 遍历当前类别文件夹下的图像文件
for filename in os.listdir(class_dir):
if filename.endswith('.jpeg'):
# 图像文件路径
image_path = os.path.join(class_dir, filename)
# 使用OpenCV加载图像
image = cv2.imread(image_path)
# 将图像转换为灰度图或其他需要的格式
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 将图像数据添加到列表
images.append(gray_image.reshape(-1))
# 将对应的标签添加到列表
labels.append(label)
# 打印数据集大小和标签示例
print('数据集大小:', len(images))
print('标签示例:', labels[999:1002])
# 加载新的数据集
X = images
y = labels
# 数据集划分
X_train, X_test, y_train, y_test = sk_model_selection.train_test_split(X, y, train_size=0.8, random_state=20)
# 数据预处理
scaler = sk_preprocessing.StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 创建SVM模型
model = svm.SVC(C=1.0, kernel='rbf', gamma='auto')
# 拟合模型
model.fit(X_train_scaled, y_train)
# 为模型打分
accuracy = model.score(X_test_scaled, y_test)
print('SVM模型评价:', accuracy)
# 网格搜索优化模型
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
grid_model = GridSearchCV(model, parameters)
grid_model.fit(X_train_scaled, y_train)
print("最优参数: ", grid_model.best_params_)
print("最优得分: ", grid_model.best_score_)
# 交叉验证
scores = sk_model_selection.cross_val_score(grid_model, X, y, cv=5)
print("交叉验证得分: ", scores)
print("平均准确率: ", scores.mean())
# 模型保存
joblib.dump(grid_model, 'svm_medical_model.pkl')
# 模型加载
loaded_model = joblib.load('svm_medical_model.pkl')
# 测试集 预测结果
predictions = loaded_model.predict(X_test_scaled)
print(predictions)
# 图像文件路径
image_path = "data/test/test.jpeg"
# 使用OpenCV加载图像
image = cv2.imread(image_path)
# 将图像转换为灰度图或其他需要的格式
test = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 对图像进行与训练数据相同的预处理
test = scaler.transform(gray_image.reshape(1, -1))
# 自己选择的图像 预测结果
predictions = loaded_model.predict(test)
print(predictions)