python fastICA

import numpy as np
import math
import random
import matplotlib.pyplot as plt
import scipy.io as scio
from scipy.fft import fft

'''
函数名称:center_data
函数功能:对数据中心化:对输入矩阵的每个元素,都减去该元素所在行(每一行一共有m个元素)的均值
输入参数:X          要处理的矩阵,大小为(n,m)
返回参数:X_center   进行中心化处理之后的矩阵,大小为(n,m)
'''


def center_data(X):
    # 沿着行的方向取均值,即计算n个麦克风在m个时刻中的均值,X_means的shape是(n,)
    X_means = np.mean(X, axis=1)
    # 将X_means增加一个新行,shape变为(n,1),X的每一列都与之对应相减
    return X - X_means[:, np.newaxis]                #np.newaxis的作用是增加一个维度。


'''
函数名称:whiten_data
函数功能:对数据白化处理
输入参数:X          要处理的矩阵,大小为(n,m)
返回参数:Z          白化处理之后的矩阵,大小为(n,m)
         V          白化变换矩阵
'''


def whiten_data(X):
    # 计算X的协方差矩阵,cov_X = E{XX^T)}
    cov_X = np.cov(X)
    # 计算协方差矩阵的特征值和特征向量
    eigenValue, eigenVector = np.linalg.eig(cov_X)
    # 将特征值向量对角化,变成对角阵,然后取逆
    eigenValue_inv = np.linalg.inv(np.diag(eigenValue))
    # 计算白化变换矩阵V
    V = np.dot(np.sqrt(eigenValue_inv), np.transpose(eigenVector))
    # 计算白化处理后得矩阵Z,Z=VX
    Z = np.dot(V, X)
    return Z, V


def gx(x, alpha=1):
    n = x.shape[0]
    for i in range(n):
        #for j in range(n):
            x[i] = x[i] * np.exp((-alpha * x[i]**2) / 2)
    return x




def div_gx(x, alpha=1):
    n = x.shape[0]
    for i in range(n):
        #for j in range(n):
            x[i] = (1 - alpha * x[i] **2) * np.exp(( - alpha * x[i]**2) / 2)
    return x


'''
函数名称:decorrelation_data
函数功能:对数据(W)进行去相关
输入参数:W                       要处理的矩阵,大小为(n,n)
返回参数:W_decorrelation         去相关之后的W
'''


def decorrelation_data(W):
    # 对WW.T进行特征值分解,D是特征值,P是特征向量
    D, P = np.linalg.eigh(np.dot(W, np.transpose(W)))
    # 特征值对角化,然后取逆
    div_D = np.linalg.inv(np.diag(D))
    # W_decorrelation = PD^(-1/2)P.T W
    return np.dot(np.dot(np.dot(P, np.sqrt(div_D)), np.transpose(P)), W)


'''
函数名称:FastICA
函数功能:对输入矩阵做ICA处理
输入参数:Z         输入矩阵(观测矩阵中心化白化之后的结果),大小为(n,m)
返回参数:W         ICA算法估计的W
        iter_num   ICA迭代次数
'''


def FastICA(Z):
    i = 0
    VariableNum,SampleNum = Z.shape
    print(VariableNum)
    print(SampleNum)
    # create w,随机生成W的值
    w = np.ones((VariableNum, VariableNum), np.float32)
    W = np.zeros((VariableNum, VariableNum), np.float32)
    for i in range(VariableNum):
        for j in range(VariableNum):
            W[i, j] = 2 * random.random() - 0.5

    # 迭代compute W
    maxIter = 500  # 设置最大迭代数量
    for j in range(VariableNum):

        while(i < maxIter):

                W_back = W[:,j]
                t = np.dot(np.transpose(Z),W[:,j])
                g = gx(t)
                dg = div_gx(t)
                W[:,j] = np.dot(Z,g)/SampleNum-np.mean(dg)*W[:,j]

                W[:,j] = W[:,j] - np.dot(np.dot(W,np.transpose(W)),W[:,j])
                W[:, j] = W[:, j] / np.linalg.norm(W[:, j])
                if np.abs(np.dot(np.transpose(W[:,j]),W_back) - 1)< 1e-11:
                    for q in range(w.shape[0]):
                        w[q,j] = W[q,j]
                    break
                i = i+1
    return w

def my_fft(x,n=1024):
    x_fft = fft(x,n=n)
    X = x_fft[0:int(x_fft.shape[0]/2)]
    X = 2*abs(X)/n
    return X


def show_data(S1,S2,MixedS1,MixedS2,y11,y12):
    S1 = S1.reshape(1000,)
    S2 = S2.reshape(1000, )
    s1_fft = my_fft(S1,n=1024)
    s2_fft = my_fft(S2,n=1024)
    MixedS1_fft = my_fft(MixedS1,n=1024)
    MixedS2_fft = my_fft(MixedS2,n=1024)
    y11_fft = my_fft(y11,n=1024)
    y12_fft = my_fft(y12,n=1024)
    plt.rcParams['font.sans-serif'] = ['simhei']  # 添加中文字体为黑体
    plt.rcParams['axes.unicode_minus'] = False

    plt.figure(dpi=300)

    plt.subplot(6,2,1)
    plt.plot(S1)
    plt.title('原始信号yy1')
    plt.subplot(6,2,2)
    plt.plot(S2)
    plt.title('原始信号yy2')
    plt.subplot(6,2,3)
    plt.plot(s1_fft)
    plt.title('原始信号1频谱')
    plt.subplot(6,2,4)
    plt.plot(s2_fft)
    plt.title('原始信号2频谱')
    plt.subplot(6,2,5)
    plt.plot(MixedS1)
    plt.title('混合信号1')
    plt.subplot(6,2,6)
    plt.plot(MixedS2)
    plt.title('混合信号2')
    plt.subplot(6,2,7)
    plt.plot(MixedS1_fft)
    plt.title('混合信号1频谱')
    plt.subplot(6,2,8)
    plt.plot(MixedS2_fft)
    plt.title('混合信号2频谱')
    plt.subplot(6,2,9)
    plt.plot(y11)
    plt.title('分离信号1')
    plt.subplot(6,2,10)
    plt.plot(y12)
    plt.title('分离信号2')
    plt.subplot(6,2,11)
    plt.plot(y11_fft)
    plt.title('分离信号1频谱')
    plt.subplot(6,2,12)
    plt.plot(y12_fft)
    plt.title('分离信号2频谱')

    name_ = '实验仿真'
    plt.savefig('./results/' + name_ + '.jpg')
    plt.close()

# 主函数入口
def main():

    Mix_ = scio.loadmat('doubl.mat')
    Mix = Mix_['doubl']
    I1_ = scio.loadmat('stri.mat')
    I1 = I1_['stri']
    I2_ = scio.loadmat('vibra.mat')
    I2 = I2_['vibra']
    #读取数据

    S1 = I1[0,:]
    S2 = I2[0,:]
    S1 = S1.reshape(1,S1.shape[0])
    S2 = S2.reshape(1,S2.shape[0])
    S = np.concatenate((S1,S2),axis=0)
    print(S.shape)

    MixedS = Mix[0:2,:]
    print(MixedS.shape)
    MixedS1 = MixedS[0, :]
    MixedS2 = MixedS[1, :]


    MixedS = center_data(MixedS)
    MixedS_white,Z = whiten_data(MixedS)

    W = FastICA(MixedS_white)

    ICAedS = np.dot(np.transpose(W),MixedS)
    y11 = ICAedS[0,:]
    y12 = ICAedS[1,:]
    show_data(S1,S2,MixedS1,MixedS2,y11,y12)

if __name__ == "__main__":
    main()

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值