#pragma once
#include <cstddef>
#include <functional>
#include <vector>
/*
埃尔米特插值
*/
// One interpolation node: abscissa x, value y = f(x) and first derivative f'(x).
//
// The type owns no resources, so it declares no copy/move constructors or
// assignment operators: the compiler-generated ones are correct and keep the
// type trivially copyable (Rule of Zero), which lets std::vector relocate
// nodes with plain memory copies.
struct InterpolationPoint {
    double x = 0.0;          // abscissa of the node
    double y = 0.0;          // function value at x
    double derivative = 0.0; // first derivative at x

    // Default constructor: node at the origin with zero slope.
    InterpolationPoint() = default;

    // Construct a node from its abscissa, value and derivative.
    InterpolationPoint(double x_val, double y_val, double derivative_val)
        : x(x_val), y(y_val), derivative(derivative_val) {}

    // Overwrite all three fields at once.
    void set(double x_val, double y_val, double derivative_val) {
        x = x_val;
        y = y_val;
        derivative = derivative_val;
    }

    // Abscissa of the node.
    double get_x() const {
        return x;
    }

    // Function value at the node.
    double get_y() const {
        return y;
    }

    // First derivative at the node.
    double get_derivative() const {
        return derivative;
    }
};
// Piecewise interpolator over a list of InterpolationPoint nodes
// (value + derivative at each node).
class HermiteInterpolator {
public:
    // Build from an explicit list of nodes; the list is copied.
    HermiteInterpolator(const std::vector<InterpolationPoint>& points);

    // Build nodes spaced evenly across `width`, one per entry of `adjPoints`
    // (used as the node's y value), each with derivative 0.
    HermiteInterpolator(int width, std::vector<int>& adjPoints);

    // Replace the current node list with a copy of `points`.
    void setPoints(const std::vector<InterpolationPoint>& points);

    // Evaluate the interpolant at x.
    double interpolate(double x);

private:
    // Return the straight line through p1 and p2 as a callable y = f(x).
    std::function<double(double)> getLineFunction(InterpolationPoint& p1, InterpolationPoint& p2);

    std::vector<InterpolationPoint> points_; // interpolation nodes
};
#include "pch.h"
#include "HermiteInterpolator.h"
#include <fstream>
// Construct the interpolator from explicit nodes; `points` is copied into points_.
HermiteInterpolator::HermiteInterpolator(const std::vector<InterpolationPoint>& points) : points_(points) {}
// Build one node per entry of adjPoints, spread evenly from x = 0 towards width:
// node i sits at x = i * width / adjPoints.size(), has y = adjPoints[i] and
// derivative 0. An empty adjPoints leaves the interpolator without nodes.
HermiteInterpolator::HermiteInterpolator(int width, std::vector<int>& adjPoints)
{
    const std::size_t count = adjPoints.size();
    if (count == 0) {
        return; // nothing to place, and the spacing below would divide by zero
    }
    // Divide in floating point: `width / adjPoints.size()` is int / size_t, which
    // converts width to an unsigned type and truncates the spacing (and wraps
    // for negative widths).
    const double step = static_cast<double>(width) / static_cast<double>(count);
    points_.reserve(count);
    for (std::size_t i = 0; i < count; ++i) {
        points_.emplace_back(step * static_cast<double>(i), static_cast<double>(adjPoints[i]), 0.0);
    }
}
void HermiteInterpolator::setPoints(const std::vector<InterpolationPoint>& points)
{
points_ = points;
}
// Return the straight line through p1 and p2 as a callable y = f(x).
//
// Uses the point-slope form y = p1.y + slope * (x - p1.x): an explicit
// intercept (p1.y - slope * p1.x) suffers cancellation when x is far from 0.
// If both points share the same x the slope is undefined; the returned
// function then yields the constant p1.y instead of inf/NaN.
std::function<double(double)> HermiteInterpolator::getLineFunction(InterpolationPoint& p1, InterpolationPoint& p2) {
    const double x0 = p1.x;
    const double y0 = p1.y;
    const double dx = p2.x - p1.x;
    if (dx == 0.0) {
        return [y0](double) { return y0; };
    }
    const double slope = (p2.y - y0) / dx;
    return [x0, y0, slope](double x) {
        return y0 + slope * (x - x0);
    };
}
// 计算三次分段Hermite插值函数的值
double HermiteInterpolator::interpolate(double x) {
int y = 0;
int n = points_.size();
if (n < 3)
{
// 获取线段函数
std::function<double(double)> lineFunction = getLineFunction(points_[0], points_[1]);
y= lineFunction(x);
}
else
{
for (int i = 0; i < n - 1; i++) {
if (x >= points_[i].x && x <= points_[i + 1].x) {
double h = points_[i + 1].x - points_[i].x;
double t = (x - points_[i].x) / h;// (x-x_k)/(x_{k+1} - x_k)
double tk = (x - points_[i + 1].x) / (-h); // (x - x_{ k + 1 }) / (x_k - x_{ k + 1 })
double y0 = (1 + 2 * t) * tk * tk;
double y1 = (1 + 2 * tk) * t * t;
double y2 = (x - points_[i].x) * tk * tk;
double y3 = (x - points_[i + 1].x) * t * t;
y= points_[i].y * y0 + points_[i + 1].y * y1 + points_[i].derivative * y2 + points_[i + 1].derivative * y3;
}
}
}
//ofstream f;
//f.open("D:\\work\\documentation\\HermiteInterpolator.txt", ios::app);
//f <<x<<"," << y << endl;
//f.close();
return y; // 如果找不到对应的插值段,返回默认值
}
为了可视化插值效果，可以启用 interpolate() 末尾被注释掉的代码，把每次计算得到的 (x, y) 以 "x,y" 的格式逐行追加写入 HermiteInterpolator.txt。
用于画图的 Python 代码：
import matplotlib.pyplot as plt

# Scatter-plot the samples produced by the (commented-out) debug dump in
# HermiteInterpolator::interpolate: one "x,y" pair per line.
with open('D:\\work\\documentation\\HermiteInterpolator.txt') as f:
    data = f.readlines()

# x holds the sample positions, y the interpolated values at those positions.
x = []
y = []
for line in data:
    # Strip the trailing newline and skip blank lines: a line read from the
    # file is never empty (it keeps its '\n'), and float('') would raise.
    line = line.strip()
    if not line:
        continue
    fields = line.split(',')
    x.append(float(fields[0]))  # first column: interpolation abscissa
    y.append(float(fields[1]))  # second column: interpolated value

plt.scatter(x, y)
plt.title('Scatter Plot')
plt.xlabel('x')
plt.ylabel('y')
# Uncomment to also save the figure (must run before plt.show()).
#plt.savefig('plot.png')
plt.show()
Python 版本的实现：
class point:
    """A sample (x, y) together with the derivative used by Hermite interpolation.

    ``derivative`` defaults to 0 and may be given explicitly when the slope at
    the sample is known; HermiteInterpolator.calculatederivative can also
    overwrite it with a finite-difference estimate.
    """

    def __init__(self, x, y, derivative=0):
        self.x = x
        self.y = y
        self.derivative = derivative
class HermiteInterpolator:
    """Piecewise cubic Hermite interpolation through a list of points.

    Every point must expose ``x``, ``y`` and ``derivative`` attributes, and the
    list is assumed to be sorted by ascending ``x``. Evaluating at an ``x``
    that no segment covers returns 0.
    """

    def __init__(self, points, ifcalculatederivative=True):
        """Keep a reference to ``points`` (the list is not copied).

        When ``ifcalculatederivative`` is true, each point's ``derivative`` is
        overwritten in place by :meth:`calculatederivative`, so pass a separate
        list if the caller's derivatives must survive. The points are then
        printed for debugging.
        """
        self.points = points
        if ifcalculatederivative:
            self.calculatederivative()
        print("point: ", ifcalculatederivative, "\n")
        for p in self.points:
            print(p.x, p.y, p.derivative)
        print(" \n")

    def calculatederivative(self):
        """Estimate every point's derivative from its neighbours, in place.

        Interior points use the central difference
        (y[i+1] - y[i-1]) / (x[i+1] - x[i-1]); the first and last points use
        one-sided differences. Lists with fewer than two points are unchanged.
        """
        pts = self.points
        n = len(pts) - 1  # index of the last point
        for i in range(1, n):
            pts[i].derivative = (pts[i + 1].y - pts[i - 1].y) / (pts[i + 1].x - pts[i - 1].x)
        if n > 0:
            pts[0].derivative = (pts[1].y - pts[0].y) / (pts[1].x - pts[0].x)
            pts[n].derivative = (pts[n].y - pts[n - 1].y) / (pts[n].x - pts[n - 1].x)

    def interpolate(self, x):
        """Evaluate the interpolant at ``x`` (same result as :meth:`HermiteInterpolate`)."""
        return self.HermiteInterpolate(x)

    def calculateLine(self, x, point1, point2):
        """Value at ``x`` of the straight line through ``point1`` and ``point2``."""
        return (point2.y - point1.y) / (point2.x - point1.x) * (x - point1.x) + point1.y

    def hermiteBase(self, x, point1, point2):
        """Cubic Hermite interpolant on [point1.x, point2.x] evaluated at ``x``.

        Combines both end values and end derivatives with the standard cubic
        Hermite basis functions. ``point1.x`` must differ from ``point2.x``.
        """
        h = point2.x - point1.x
        t = (x - point1.x) / h  # (x - x_k) / (x_{k+1} - x_k)
        tk = (x - point2.x) / (-h)  # (x - x_{k+1}) / (x_k - x_{k+1})
        y0 = (1 + 2 * t) * tk * tk  # weight of y_k
        y1 = (1 + 2 * tk) * t * t  # weight of y_{k+1}
        y2 = (x - point1.x) * tk * tk  # weight of y'_k
        y3 = (x - point2.x) * t * t  # weight of y'_{k+1}
        return point1.y * y0 + point2.y * y1 + point1.derivative * y2 + point2.derivative * y3

    def _findSegment(self, x):
        """Return the first pair of neighbouring points whose closed x-interval
        contains ``x``, or None.

        Zero-width segments (duplicate x) are skipped because the basis
        functions would divide by zero on them; neighbouring segments agree at
        a shared point, so the first match is enough.
        """
        pts = self.points
        for left, right in zip(pts, pts[1:]):
            if left.x <= x <= right.x and left.x != right.x:
                return left, right
        return None

    def HermiteInterpolate(self, x):
        """Hermite interpolant at ``x``; 0 when no segment contains ``x``."""
        segment = self._findSegment(x)
        if segment is None:
            return 0
        left, right = segment
        return self.hermiteBase(x, left, right)

    def HermiteInterpolateAdvance(self, x):
        """Like :meth:`HermiteInterpolate`, but flat segments are drawn straight.

        On a segment whose end points share the same ``y``, non-zero end
        derivatives would make the cubic bulge away from that level; using the
        straight line keeps such a segment flat.
        """
        segment = self._findSegment(x)
        if segment is None:
            return 0
        left, right = segment
        if left.y == right.y:
            return self.calculateLine(x, left, right)
        return self.hermiteBase(x, left, right)
def plotHermiteFunction():
    """Plot the Hermite interpolant of a small fixed data set.

    ``hermiteInterpolator`` estimates node derivatives from the data, while
    ``hermiteInterpolator2`` keeps every derivative at 0. Only the
    ``HermiteInterpolateAdvance`` curve of the latter is drawn; uncomment the
    alternative lines below to compare the variants.
    """
    # Imported here so the function works regardless of module-level imports.
    import matplotlib.pyplot as plt
    import numpy as np

    points = [point(0, 0), point(1, 1), point(2, 0), point(3, 1), point(4, 1), point(5, 1)]
    # Independent copy: HermiteInterpolator(points) overwrites the derivatives
    # of `points` in place, while points2 must keep derivative 0.
    points2 = [point(p.x, p.y) for p in points]

    # Kept for the commented-out comparison curve y1 below.
    hermiteInterpolator = HermiteInterpolator(points)
    hermiteInterpolator2 = HermiteInterpolator(points2, False)

    x = np.linspace(points[0].x, points[-1].x, 200)
    y2 = [hermiteInterpolator2.HermiteInterpolateAdvance(i) for i in x]
    plt.plot(x, y2, label='hermiteAdvance', color='red')

    # Alternative curves for comparison:
    #y = [hermiteInterpolator2.HermiteInterpolate(i) for i in x]
    #plt.plot(x, y, label='hermite', color='deepskyblue')
    #y1 = [hermiteInterpolator.HermiteInterpolateAdvance(i) for i in x]
    #plt.plot(x, y1, label='hermiteAdvance', color='green')

    # Overlay the data points themselves.
    plt.scatter([p.x for p in points2],  # x coordinates
                [p.y for p in points2],  # y coordinates
                c='yellow',  # marker colour
                label='点')  # legend entry ("points")
    plt.legend()
    plt.show()