1.这里适应度函数选取的是测试集准确率
没有优化前的SVM分类效果是95.45%,优化后的效果是98.67% ,适应度曲线如下
2.代码(可以直接运行)
# -*- coding: utf-8 -*-
"""
Created on Sat May 27 08:10:09 2023
"""
import numpy as np
from sklearn import svm
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
# 定义烟花算法的适应度函数(用于评估参数组合的性能)
def fitness_function(params):
C, log_gamma = params
gamma = 10.0 ** log_gamma
clf = svm.SVC(C=C, gamma=gamma)
scores = cross_val_score(clf, X, y, cv=5)
accuracy = np.mean(scores)
error = 1 - accuracy # 计算误差
return accuracy
# 定义烟花算法的粒子类
class Particle:
def __init__(self, position):
self.position = position
self.velocity = np.zeros_like(position)
self.best_position = np.copy(position)
self.best_fitness = fitness_function(position)
# 初始化数据集(这里使用的是Iris数据集)
iris = load_iris()
X = iris.data
y = iris.target
# 定义参数搜索空间
param_space = [
np.logspace(1, 3, num=100), # C参数范围
np.linspace(1, 3, num=100) # log(gamma)参数范围
]
# 初始化烟花算法的粒子群
population_size = 20
particles = []
for _ in range(population_size):
params = [np.random.choice(p) for p in param_space]
particle = Particle(params)
particles.append(particle)
# 迭代优化过程
max_iterations = 50
best_fitnesses = []
for iteration in range(max_iterations):
for particle in particles:
# 更新粒子的速度和位置
acceleration = np.random.uniform(-1, 1, size=len(param_space))
particle.velocity = 0.5 * particle.velocity + acceleration
particle.position += particle.velocity
# 评估粒子的适应度并更新最佳位置
fitness = fitness_function(particle.position)
if fitness > particle.best_fitness:
particle.best_position = np.copy(particle.position)
particle.best_fitness = fitness
# 记录当前迭代的最佳适应度
best_particle = max(particles, key=lambda p: p.best_fitness)
best_fitnesses.append(best_particle.best_fitness)
# 打印当前迭代的结果
print("Iteration:", iteration + 1)
print("Best Fitness:", best_particle.best_fitness)
print("Best Parameters:", best_particle.best_position)
print()
# 输出最佳参数和性能
best_particle = max(particles, key=lambda p: p.best_fitness)
best_params = best_particle.best_position
best_accuracy = best_particle.best_fitness
print("Best Parameters:", best_params)
print("Best Accuracy:", best_accuracy)
# 绘制迭代曲线
plt.plot(range(1, max_iterations + 1), best_fitnesses)
plt.xlabel("Iteration")
plt.ylabel("Best Fitness")
plt.title("Iteration Curve")
plt.show()
3.若对其它优化算法感兴趣,请关注我的优化算法专栏