import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
def F(X):
F = 3 * (1 - X[0]) ** 2 * np.exp(-(X[0] ** 2) - (X[1] + 1) ** 2) - 10 * (
X[0] / 5 - X[0] ** 3 - X[1] ** 5) * np.exp(
-X[0] ** 2 - X[1] ** 2) - 1 / 3 ** np.exp(-(X[0] + 1) ** 2 - X[1] ** 2)
return F
def fitness(X):
fit_value = F(X) # 求最大
#fit_value = -F(X) # 求最小
return fit_value
def plot_3d(ax):
X = np.linspace(-3, 3, 100)
Y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(X, Y)
Z = F([X, Y])
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm)
ax.set_zlim(-10, 10)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.pause(3)
plt.show()
def main():
fig1 = plt.figure()
ax = Axes3D(fig1)
plt.ion() # 将画图模式改为交互模式,程序遇到plt.show不会暂停,而是继续执行
plot_3d(ax)
# 参数初始化
p_num = 40 # 粒子数
max_iter = 100
w = 0.8 # 惯性权重
c1 = 2 # 局部学习因子
c2 = 2 # 全局学习因子
dim = 2 # 搜索维度
X = np.zeros((p_num, dim)) # 位置
V = np.zeros((p_num, dim)) # 速度
p = np.zeros((p_num, dim)) # 个体最优(位置)
p_best = np.zeros((1, dim)) # 全局最优(位置)
f = np.zeros(p_num) # 个体最优
f_best = -np.inf # 全局最优
V_max = 0.2
# 初始化
for i in range(p_num):
for j in range(dim):
X[i][j] = np.random.random(1) * 6 - 3
V[i][j] = np.random.random(1) * V_max * 2 - V_max
p[i] = X[i]
temp = fitness(X[i]) # 个体最优为初始位置
f[i] = temp
if temp > f_best: # 全局最优为适应度最大的个体的位置
f_best = temp
p_best = X[i]
# 进入粒子群迭代
for t in range(max_iter):
F_value = [] # 储存目标函数值,用于可视化
for i in range(p_num):
r1 = np.random.random(1)
r2 = np.random.random(1)
# 更新速度
V[i] = w * V[i] + c1 * r1 * (p[i] - X[i]) + c2 * r2 * (p_best - X[i])
# 更新位置
X[i] = X[i] + V[i]
# 速度限制
V[i][V[i] > V_max] = V_max
V[i][V[i] < -V_max] = -V_max
# 位置限制
X[i][X[i] > 3] = 3
X[i][X[i] < -3] = -3
F_value.append(F(X[i]))
# 更新个体最优和全局最优
for i in range(p_num):
temp = fitness(X[i])
if temp > f[i]: # 更新个体最优
f[i] = temp
p[i] = X[i]
if temp > f_best: # 更新全局最优
p_best = X[i]
f_best = temp
if 'sca' in locals():
sca.remove() # 去除图像中上一个种群的点
sca = ax.scatter(X[:, 0], X[:, 1], F_value, c='black', marker='o')
plt.show()
plt.pause(0.1)
plt.ioff()
plot_3d(ax)
plt.show()