Kalman 滤波
最近在学习 Kalman 滤波的原理，希望通过用 Python 亲手实现一遍，加深对公式和原理的理解。
在此记录一下。
import numpy as np
import math
#import matplotlib.pyplot as plt
class Kalman(object):
    """Linear Kalman filter; attribute names mirror OpenCV's ``cv2.KalmanFilter``.

    Parameters
    ----------
    dynam_params : int
        Dimension n of the state vector.
    measure_params : int
        Dimension m of the measurement vector.
    control_params : int, optional
        Dimension c of the control vector; 0 (default) means no control input.
    type : numpy dtype, optional
        dtype of the matrices allocated here (default ``np.float32``). The name
        shadows the builtin ``type`` inside ``__init__``; it is kept unchanged
        for backward compatibility with existing callers.

    All vectors are column vectors of shape (k, 1). The matrices are public
    attributes that the caller is expected to overwrite after construction:

    - transitionMatrix    F  (n, n)  initialised to the identity
    - controlMatrix       B  (n, c)  zeros, or None when control_params == 0
    - measurementMatrix   H  (m, n)  zeros -- must be set by the caller
    - processNoiseCov     Q  (n, n)  identity
    - measurementNoiseCov R  (m, m)  identity
    - errorCovPre / errorCovPost     prior P'(k) / posterior P(k), zeros
    - statePre / statePost           prior x'(k) / posterior x(k), zeros
    - gain                K  (n, m)  Kalman gain from the last correct()

    Put the initial estimate in ``statePost`` (and ``errorCovPost``).
    predict() derives ``statePre`` from them and overwrites whatever
    ``statePre`` held before.
    """

    def __init__(self, dynam_params, measure_params, control_params=0, type=np.float32):
        self.dynam_params = dynam_params
        self.measure_params = measure_params
        self.control_params = control_params
        n, m = dynam_params, measure_params
        # B is only allocated when the model has a control input.
        if control_params != 0:
            self.controlMatrix = np.zeros((n, control_params), type)
        else:
            self.controlMatrix = None
        self.errorCovPost = np.zeros((n, n), type)          # P(k): posterior covariance
        self.errorCovPre = np.zeros((n, n), type)           # P'(k): prior covariance
        self.gain = np.zeros((n, m), type)                  # K
        self.measurementMatrix = np.zeros((m, n), type)     # H
        self.measurementNoiseCov = np.eye(m, dtype=type)    # R, identity by default
        self.processNoiseCov = np.eye(n, dtype=type)        # Q, identity by default
        self.transitionMatrix = np.eye(n, dtype=type)       # F, identity by default
        self.statePost = np.zeros((n, 1), type)             # x(k): posterior estimate
        self.statePre = np.zeros((n, 1), type)              # x'(k): prior estimate

    def predict(self, control_vector=None):
        """Time update: propagate the posterior estimate one step forward.

            x'(k) = F x(k-1) [+ B u(k)]
            P'(k) = F P(k-1) F^T + Q

        Parameters
        ----------
        control_vector : array-like, optional
            Control input u(k) with control_params elements. It is reshaped
            to a column vector. It is ignored when the filter has no control
            input, and the control term is skipped when it is None.

        Returns
        -------
        The predicted state x'(k), shape (n, 1).

        Notes
        -----
        Afterwards statePost/errorCovPost hold copies of the prior. This way,
        if no measurement arrives before the next predict(), the prior is
        carried forward as the estimate.
        """
        F = self.transitionMatrix
        x_pred = np.dot(F, self.statePost)
        if self.control_params != 0 and control_vector is not None:
            # Force a column so B u is (n, 1) and does not broadcast to (n, n).
            u = np.reshape(np.asarray(control_vector), (-1, 1))
            x_pred = x_pred + np.dot(self.controlMatrix, u)
        self.statePre = x_pred
        self.errorCovPre = np.dot(np.dot(F, self.errorCovPost), F.T) + self.processNoiseCov
        # Copy rather than alias, so an in-place edit of one attribute
        # cannot silently change the other.
        self.statePost = self.statePre.copy()
        self.errorCovPost = self.errorCovPre.copy()
        return x_pred

    def correct(self, mes):
        """Measurement update: fold measurement z(k) into the prior estimate.

            K(k) = P'(k) H^T (H P'(k) H^T + R)^-1
            x(k) = x'(k) + K(k) (z(k) - H x'(k))
            P(k) = P'(k) - K(k) H P'(k)

        All three formulas use the prior (statePre / errorCovPre) written by
        the last predict(). Calling correct() again without predict()
        therefore recomputes from the same prior instead of mixing prior and
        posterior quantities.

        Parameters
        ----------
        mes : array-like
            Measurement z(k) with as many elements as H has rows. It may be
            given as a column, a row or a flat vector; it is reshaped to a
            column so the innovation is (m, 1) rather than broadcasting to
            (m, m).

        Returns
        -------
        The corrected state x(k), shape (n, 1). It is also stored in statePost.

        Raises
        ------
        ValueError
            If the measurement size does not match H.
        numpy.linalg.LinAlgError
            If the innovation covariance H P' H^T + R is singular.
        """
        H = self.measurementMatrix
        P_pre = self.errorCovPre
        x_pre = self.statePre
        z = np.reshape(np.asarray(mes), (-1, 1))
        if z.shape[0] != H.shape[0]:
            raise ValueError("measurement has %d elements, expected %d"
                             % (z.shape[0], H.shape[0]))

        PHt = np.dot(P_pre, H.T)                          # P' H^T, (n, m)
        S = np.dot(H, PHt) + self.measurementNoiseCov     # innovation covariance, (m, m)
        # K S = P' H^T  <=>  S^T K^T = (P' H^T)^T.
        # Solving this system is numerically better than forming inv(S).
        K = np.linalg.solve(S.T, PHt.T).T
        self.gain = K

        innovation = z - np.dot(H, x_pre)
        x_post = x_pre + np.dot(K, innovation)
        self.statePost = x_post
        self.errorCovPost = P_pre - np.dot(np.dot(K, H), P_pre)
        return x_post
if __name__ == '__main__':
    # Demo: a 2-D position measured directly (H = I) under a model where the
    # state is expected to stay put between steps (F = I).
    pos = np.array([
        [10, 50],
        [12, 49],
        [11, 52],
        [13, 52.2],
        [12.9, 50]], np.float32)
    kalman = Kalman(2, 2)
    # Plain arrays instead of np.mat: np.mat was removed in NumPy 2.0, and
    # np.matrix is discouraged.
    kalman.measurementMatrix = np.eye(2, dtype=np.float32)
    kalman.transitionMatrix = np.eye(2, dtype=np.float32)
    kalman.processNoiseCov = np.eye(2, dtype=np.float32) * 1e-4
    kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 1e-4
    # The initial estimate belongs in statePost. predict() computes statePre
    # from statePost, so a value assigned to statePre would be overwritten
    # before it is ever used.
    kalman.statePost = np.array([[6], [6]], np.float32)
    for row in pos:
        mes = np.reshape(row, (2, 1))
        y = kalman.predict()
        x = kalman.correct(mes)
        print('measurement:\t', mes[0, 0], mes[1, 0])
        print('correct:\t', x[0, 0], x[1, 0])
        print('predict:\t', y[0, 0], y[1, 0])
        print('=' * 30)
用同一组数据，与上一篇中用现成库函数实现的版本对比了输出，结果一致。
本人是新手，如有问题，欢迎指正。