简介:
scipy.optimize.leastsq对多变量和多参数曲线拟合,多变量指的是多个输入变量,例如x1、x2…而不是单个变量x。多参数指的是需要求解的函数含有多个未知参数,例如k1、k2…。如下,我要拟合的曲线是:
y
=
k
1
x
1
+
k
2
x
2
+
k
3
y=k_1x_1+k_2x_2+k_3
y=k1x1+k2x2+k3
代码:
from typing import List
from matplotlib import pyplot as plt
import numpy as np
from numpy import polyfit, poly1d
from scipy import linalg
from scipy.optimize import leastsq
def predesign_fun(data, arg):
"""
预先设计的函数
:return:
"""
k1, k2, k3 = arg
return k1 * data[0] + k2 * data[1] + k3
def error(arg, y_value, x_data):
"""
误差计算
:return:
"""
return y_value - predesign_fun(x_data, arg)
def get_data():
"""
获取数据
:return:
"""
data = [[50, 48, 242],
[70, 28, 198],
[40, 50, 202],
[80, 20, 162],
[100, 24, 242],
[30, 57, 173],
[60, 24, 145],
[47, 28, 133],
[65, 24, 158],
[71, 29, 208],
[60, 5, 30],
[80, 80, 646],
[60, 40, 242],
[47, 58, 275],
[46, 58, 269],
[60, 30, 182],
[65, 25, 164],
[52, 25, 131],
[54, 68, 371]]
return data
def curve(data):
"""
曲线拟合
:return:
"""
x, y, z = [], [], []
for row in data:
x.append(row[0])
y.append(row[1])
z.append(row[2])
# coeff = polyfit()
x_data = (np.array(x), np.array(y))
print(x)
print(y)
print(z)
arg0 = (10, 10, 10)
result = leastsq(error, arg0, args=(z, x_data))
return result[0]
def paint_3D_line(data):
"""
绘制三维图示
:param data:
:return:
"""
# 定义图像和三维格式坐标轴
x, y, z = [], [], []
for row in data:
x.append(row[0])
y.append(row[1])
z.append(row[2])
ax1 = plt.axes(projection='3d')
# 下面 zd xd yd是三个坐标轴
xd = np.array(x)
yd = np.array(y)
zd = np.array(z)
# 绘制散点图,属性基本还是matplotlib画图的属性设置,
ax1.scatter3D(xd, yd, zd, c='red')
plt.show()
def cal_absolute_error(coe, x_data, y_data):
"""
测试结果
:return:
"""
return coe[0] * x_data[0] + coe[1] * x_data[1] + coe[2] - y_data
def cal_relatively_error(coe, x_data, y_data):
"""
测试结果
:return:
"""
return (coe[0] * x_data[0] + coe[1] * x_data[1] + coe[2] - y_data) / y_data
if __name__ == '__main__':
coef = curve(data=get_data())
print(coef)
paint_3D_line(data=get_data())
error_value = cal_absolute_error(coef, [50, 10], 50)
print(error_value)
error_value = cal_relatively_error(coef, [50, 10], 50)
print(error_value)