上智能优化课的作业,老师布置的作业。无奈上课没听,不知道怎么做。找了一圈也没找到好用的代码,就一个C++的,用起来还麻烦。于是我索性自己写了一个python版的。
import numpy as np
def func(x, y):
return x ** 2 - 4 * x + y ** 2 - y - x * y
def go(good: np.array, medium:np.array, bad:np.array, f: 'function'):
alpha = good - bad
beta = medium - bad
new = bad + alpha + beta # 先做一次反射
if f(*new) < f(*good): # 反射点比最好点好
new_ex = new + (alpha + beta) / 2 # 扩张反射点
if f(*new_ex) < f(*new): # 扩张点比反射点好
return [good, medium, new_ex]
return [good, medium, new]
elif f(*new) > f(*good): # 反射点比最好点差
new_sr = new - (alpha + beta) / 4 # 收缩扩张点
if f(*new_sr) < f(*bad): # 收缩点1比最坏点好
return [good, medium, new_sr]
bad_sr = bad + (alpha + beta) / 4 # 最坏点收缩
if f(*bad_sr) < f(*bad): # 收缩点2比最坏点好
return [good, medium, bad_sr]
good2 = bad - alpha / 2 # 压缩
medium2 = medium - beta / 2
return [good2, medium2, bad]
else:
return [good, medium, new]
def simplex(x0, f, d, n=100):
'''
x0: 初试点
f: 目标函数
d: 精度(没用上)
n: 最大迭代次数
'''
sol = [np.array(x) for x in x0]
pre_val = 1e10
while True:
sol = sorted(sol, key=lambda x: f(*x))
if n < 0: # 防止循环次数过多
return
n -= 1
val = f(*sol[0])
# if abs(val - pre_val) < d: # 本来想用差值小于定值作为循环结束条件
# return # 结果可能是我逻辑有问题,有时候一次迭代后函数值不变
pre_val = val
print('solution: ', [list(np.round(x, 5)) for x in sol], ', value: ', [f(*x) for x in sol], ', best: ', val)
sol = go(*sol, f)
if __name__ == '__main__':
simplex([[0, 0], [1.2, 0], [0, 0.8]], func, 1e-5, 24)
如果代码有问题欢迎大佬指正,我感觉可能有点小问题,不过可以正常求解。
solution: [[1.2, 0.0], [0.0, 0.8], [0, 0]] , value: [-3.36, -0.15999999999999992, 0] , best: -3.36
solution: [[1.8, 1.2], [1.2, 0.0], [0.0, 0.8]] , value: [-5.88, -3.36, -0.15999999999999992] , best: -5.88
solution: [[1.8, 1.2], [2.25, 0.5], [1.2, 0.0]] , value: [-5.88, -5.3125, -3.36] , best: -5.88
solution: [[2.85, 1.7], [1.8, 1.2], [2.25, 0.5]] , value: [-6.932499999999999, -5.88, -5.3125] , best: -6.932499999999999
solution: [[2.85, 1.7], [2.3625, 1.925], [1.8, 1.2]] , value: [-6.932499999999999, -6.63578125, -5.88] , best: -6.932499999999999
solution: [[3.00937, 2.11875], [2.85, 1.7], [2.3625, 1.925]] , value: [-6.986923828125, -6.932499999999999, -6.63578125] , best: -6.986923828125