Kalman滤波 python代码实现

kalman滤波

在学kalman滤波的原理,希望通过python能加深对公式和原理的理解。

记录一下

import numpy as np
import math
#import matplotlib.pyplot as plt

'''
    dynam_params:状态空间的维数;
    measure_params:测量值的维数; 
    control_params:控制向量的维数,默认为0。
'''
class Kalman(object):
    '''
        INIT KALMAN
    '''
    def __init__(self, dynam_params, measure_params,control_params = 0,type = np.float32):
        self.dynam_params = dynam_params
        self.measure_params = measure_params
        self.control_params = control_params
        # self

        ### 以下都应该可以根据输入的维度值确定维度的
        if(control_params != 0):
            self.controlMatrix = np.array(np.zeros((dynam_params, control_params)),type) # 控制矩阵
        else:
            self.controlMatrix = None
        self.errorCovPost =  np.array(np.zeros((dynam_params, dynam_params)),type) # P_K
        self.errorCovPre =  np.array(np.zeros((dynam_params, dynam_params)),type) # P_k-1
        self.gain =  np.array(np.zeros((dynam_params, measure_params)),type) # K

        self.measurementMatrix =  np.array(np.zeros((measure_params, dynam_params)),type) # 测量矩阵 H
        self.measurementNoiseCov =  np.array(np.zeros((measure_params, measure_params)),type) # 测量噪声 R
        self.processNoiseCov =  np.array(np.zeros((dynam_params, dynam_params)),type) # 过程噪声 Q
        self.transitionMatrix =  np.array(np.zeros((dynam_params, dynam_params)),type) # 状态转移矩阵 F

        self.statePost =  np.array(np.zeros((dynam_params, 1)),type)
        self.statePre =  np.array(np.zeros((dynam_params, 1)),type)
        
        # 对角线初始化为1
        ## np.diag_indices 以元组的形式返回主对角线的索引
        ### F 状态转移矩阵 对角线初始化为1
        row,col = np.diag_indices(self.transitionMatrix.shape[0])
        self.transitionMatrix[row,col] = np.array(np.ones(self.transitionMatrix.shape[0]))
        ### R 测量噪声 对角线初始化为1
        row,col = np.diag_indices(self.measurementNoiseCov.shape[0])
        self.measurementNoiseCov[row,col] = np.array(np.ones(self.measurementNoiseCov.shape[0]))
        ### Q 过程噪声 对角线初始化为1
        row,col = np.diag_indices(self.processNoiseCov.shape[0])
        self.processNoiseCov[row,col] = np.array(np.ones(self.processNoiseCov.shape[0]))
    
    def predict(self,control_vector = None):
        '''
            PREDICT
        '''
        # 预测值
        F = self.transitionMatrix
        x_update = self.statePost
        B = self.controlMatrix

        if(self.control_params == 0):
            x_predict = np.dot(F, x_update)
        else:
            x_predict = np.dot(F, x_update) + np.dot(B, control_vector)
        self.statePre = x_predict

        # P_k
        P_k_minus = self.errorCovPost
        Q = self.processNoiseCov
        temp1 = np.dot(F, P_k_minus)
        self.errorCovPre = np.dot(temp1, F.T) + Q
        #self.errorCovPre = F * P_k_minus * F.T + Q

        self.statePost = self.statePre
        self.errorCovPost = self.errorCovPre
        return x_predict

    def correct(self,mes):
        '''
            CORRECT
        '''
        # K 更新
        K = self.gain
        P_k = self.errorCovPost
        H = self.measurementMatrix
        R = self.measurementNoiseCov

        temp1 = np.dot(P_k, H.T)
        temp2 = np.dot(H, P_k)
        temp3 = np.dot(temp2, H.T) + R
        K = np.dot(temp1, np.linalg.inv(temp3))
        #K = P_k * H.T * np.linalg.inv(H * P_k * H.T + R)
        self.gain = K

        # 计算State的估计值
        x_predict = self.statePre
        temp4 = mes - np.dot(H, x_predict)
        temp5 = np.dot(K, temp4)
        x_update = x_predict + temp5
        print("H",H)
        print("x_predict",x_predict)
        print("temp4",temp4)
        print("temp5",temp5)
        print("update",x_update)
        
        # x_update = x_predict +  K * (mes - H * x_predict)
        self.statePost = x_update
        print("====",self.statePost)
        print(self.statePre)
        print(self.statePre[0])
        print(self.statePost[0])
        # P_k更新
        P_pre = self.errorCovPre
        temp5 = np.dot(K, H)
        temp6 = np.dot(temp5, P_pre)
        P_k_post = P_pre - temp6
        # P_k_post = P_pre - K * H * P_pre
        self.errorCovPost = P_k_post
        return x_update
        
if __name__ == '__main__':
    
    pos = np.array([
        [10,    50],
        [12,    49],    
        [11,    52],     
        [13,    52.2],     
        [12.9,  50]], np.float32)
    kalman = Kalman(2,2)
    kalman.measurementMatrix = np.mat([[1,0],[0,1]],np.float32)
    kalman.transitionMatrix = np.mat([[1,0],[0,1]], np.float32)
    kalman.processNoiseCov = np.mat([[1,0],[0,1]], np.float32) * 1e-4
    kalman.measurementNoiseCov = np.mat([[1,0],[0,1]], np.float32) * 1e-4 
    kalman.statePre =  np.mat([[6],[6]],np.float32)
   #kalman.statePre =  np.mat([[6],[6]],np.float32)
    for i in range(len(pos)):
        mes = np.reshape(pos[i,:],(2,1))
        y = kalman.predict()
        
        print("before correct mes",mes[0],mes[1])
        x = kalman.correct(mes)
        print (kalman.statePost[0],kalman.statePost[1])
        print (kalman.statePre[0],kalman.statePre[1])
        print ('measurement:\t',mes[0],mes[1])  
        print ('correct:\t',x[0],x[1])
        print ('predict:\t',y[0],y[1])     
        print ('='*30)  
    

跟上一篇用Python提供的函数测试了一下输出,是一样的。
小白,有问题请多指教。

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值