EM算法学习

EM算法学习记录

  • 开始学习机器学习，٩(๑>◡<๑)۶

算法思路

参考资料：https://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html

实现代码

import math
import numpy

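# Background (Jensen's inequality, the basis of the EM derivation):
# if f''(x) >= 0 everywhere, i.e. f is convex, then for any random
# variable X:  E[f(X)] >= f(E[X]).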

class gaussian_distribution:
    """Isotropic (spherical) Gaussian density used as one mixture component.

    Attributes:
        mu: the mean -- a scalar for 1-D data, or a list of floats for points.
        sigma: the per-coordinate standard deviation (square root of the
            variance), shared by every coordinate.
        probability: the component's mixture weight; 0.0 until set by the
            caller.
    """

    def __init__(self, mu, sigma):
        # Store the parameters on the instance. Assigning to the class (as the
        # original did) made every instance share, and overwrite, one set.
        self.mu = mu
        self.sigma = sigma
        self.probability = 0.0

    def get_value(self, x):
        """Return the Gaussian density at x.

        If x is a point (list or tuple), the density of an isotropic Gaussian
        in d = len(x) dimensions is returned:
            exp(-||x - mu||^2 / (2 sigma^2)) / (sqrt(2 pi) sigma)^d
        Otherwise x is treated as a scalar (numpy arrays are handled
        element-wise) and the 1-D density is returned.
        """
        if isinstance(x, (list, tuple)):
            diff = (numpy.asarray(x, dtype=float)
                    - numpy.asarray(self.mu, dtype=float)).ravel()
            dim = diff.size
            squared_distance = float(numpy.dot(diff, diff))
            # The normaliser must be raised to the dimension; using the 1-D
            # one biases responsibilities between components of different sigma.
            normaliser = (math.sqrt(2 * math.pi) * self.sigma) ** dim
            return math.exp(-squared_distance / (2 * self.sigma ** 2)) / normaliser
        # NOTE: 2-D (matrix) input is not given special handling here.
        # 1-D case: the exponent is -(x - mu)^2 / (2 sigma^2); the parentheses
        # around the denominator were missing in the original.
        return (numpy.exp(-(x - self.mu) ** 2 / (2 * self.sigma ** 2))
                / (math.sqrt(2 * math.pi) * self.sigma))

    def change_mu(self, mu):
        """Set the mean."""
        self.mu = mu

    def change_sigma(self, sigma):
        """Set the per-coordinate standard deviation."""
        self.sigma = sigma

    def get_average_value(self, value_list):
        """Return the coordinate-wise mean of value_list as a list of floats.

        Raises:
            ValueError: if value_list is empty.
        """
        points = numpy.asarray(value_list, dtype=float)
        if points.size == 0:
            raise ValueError("value_list must not be empty")
        return points.mean(axis=0).tolist()

    def get_sigma(self, value_list):
        """Return the pooled per-coordinate standard deviation of value_list.

        The spread is measured about the sample mean (not the origin) and
        averaged over coordinates, so the result matches the isotropic sigma
        used by get_value: sqrt(mean(||x - mean||^2) / d).
        Scalar samples are treated as 1-D points.

        Raises:
            ValueError: if value_list is empty.
        """
        points = numpy.asarray(value_list, dtype=float)
        if points.size == 0:
            raise ValueError("value_list must not be empty")
        points = points.reshape(len(points), -1)
        deviations = points - points.mean(axis=0)
        # points.size == n * d, so this divides by both counts at once.
        return math.sqrt(float((deviations ** 2).sum()) / points.size)

class EM_gauss:
    """Fit a mixture of isotropic Gaussians to point data with the EM algorithm.

    Attributes:
        gauss_dis_list: the gaussian_distribution components. Each has ``mu``
            (list of floats), ``sigma`` (per-coordinate standard deviation)
            and ``probability`` (mixture weight).
        data_list: the data points, each a list of numbers of equal length.
    """

    def __init__(self, k, db):
        """Create k components initialised from db.

        Args:
            k: number of Gaussian components.
            db: list of points (lists of numbers, all the same length).
        """
        # Per-instance state. Storing these on the class made every EM_gauss
        # object share, and overwrite, the same component list.
        self.gauss_dis_list = list()
        self.data_list = db
        count = len(db)
        for model_index in range(k):
            gauss_dis = gaussian_distribution(0, 1)
            # Seed each mean at a different data point, spread evenly through
            # db. If every component starts with identical parameters, every
            # point gets identical responsibilities and EM can never split them.
            gauss_dis.mu = [float(v) for v in db[model_index * count // k]]
            # All components start with the spread of the whole data set.
            gauss_dis.sigma = gauss_dis.get_sigma(db)
            gauss_dis.probability = 1.0 / k
            self.gauss_dis_list.append(gauss_dis)

    def _responsibilities(self, data_index):
        """Return the posterior probability of every component for one point.

        Components whose sigma is not positive (collapsed) contribute 0. If
        every weighted density underflows to 0, the point is shared evenly
        instead of dividing by zero.
        """
        point = self.data_list[data_index]
        weighted = [
            gauss_dis.probability * gauss_dis.get_value(point)
            if gauss_dis.sigma > 0 else 0.0
            for gauss_dis in self.gauss_dis_list
        ]
        total = sum(weighted)
        if not total > 0:
            return [1.0 / len(weighted)] * len(weighted) if weighted else []
        return [value / total for value in weighted]

    def get_data_probability(self, data_index, model_index):
        """Return the responsibility of component model_index for point data_index."""
        return self._responsibilities(data_index)[model_index]

    def get_model_probability(self, model_index):
        """Return N_k: the sum over all points of component model_index's responsibility."""
        return sum(self.get_data_probability(data_index, model_index)
                   for data_index in range(len(self.data_list)))

    def round(self):
        """Run one EM iteration, updating every component in place.

        E-step: responsibilities are computed once, from the current
        parameters of all components.
        M-step: for each component k, with N_k = sum of its responsibilities:
            mu_k    = sum_j r_jk x_j / N_k
            sigma_k = sqrt(sum_j r_jk ||x_j - mu_k||^2 / (N_k * d))
            weight  = N_k / n
        A component that received no responsibility at all is marked
        degenerate (sigma = 0, weight = 0), so rounds() removes it.
        """
        count = len(self.data_list)
        if count == 0 or not self.gauss_dis_list:
            return
        points = numpy.asarray(self.data_list, dtype=float).reshape(count, -1)
        dim = points.shape[1]
        # E-step: shape (count, number of components).
        resp = numpy.array([self._responsibilities(j) for j in range(count)],
                           dtype=float)
        weights = resp.sum(axis=0)
        # M-step: each component uses its own column and its own N_k.
        for model_index, gauss_dis in enumerate(self.gauss_dis_list):
            n_k = float(weights[model_index])
            if n_k <= 0:
                gauss_dis.sigma = 0.0
                gauss_dis.probability = 0.0
                continue
            column = resp[:, model_index]
            new_mu = numpy.dot(column, points) / n_k
            squared = ((points - new_mu) ** 2).sum(axis=1)
            # Isotropic model: average the squared deviation over coordinates
            # so sigma stays a per-coordinate standard deviation.
            variance = float(numpy.dot(column, squared)) / (n_k * dim)
            gauss_dis.mu = new_mu.tolist()
            gauss_dis.sigma = math.sqrt(variance)
            gauss_dis.probability = n_k / count

    def _parameter_snapshot(self):
        """Return each component's mean coordinates followed by its sigma, flattened."""
        snapshot = []
        for gauss_dis in self.gauss_dis_list:
            snapshot += gauss_dis.mu
            snapshot.append(gauss_dis.sigma)
        return snapshot

    def _drop_degenerate_models(self):
        """Remove, in place, components whose sigma is not positive (or is NaN)."""
        self.gauss_dis_list[:] = [gauss_dis for gauss_dis in self.gauss_dis_list
                                  if gauss_dis.sigma > 0]

    def rounds(self, error=0.0001, max_iter=1000):
        """Iterate EM until the parameters stop changing.

        Before every iteration, components whose sigma has collapsed to 0 are
        removed. Iteration stops when no mean coordinate or sigma moves by more
        than error between two consecutive iterations, or after max_iter
        iterations (None means no limit).

        Args:
            error: convergence tolerance on parameter changes.
            max_iter: maximum number of iterations.

        Returns:
            The number of EM iterations performed.
        """
        iterations = 0
        while max_iter is None or iterations < max_iter:
            self._drop_degenerate_models()
            previous = self._parameter_snapshot()
            self.round()
            iterations += 1
            if self.check_gauss_equal(previous, error):
                break
        return iterations

    def check_gauss_equal(self, list_a, error=0.0001):
        """Return True if list_a matches the current parameters within error.

        list_a must have the layout produced by _parameter_snapshot (each
        component's mean coordinates, then its sigma). A length mismatch
        (components added or removed) counts as "not equal".
        """
        list_b = self._parameter_snapshot()
        if len(list_a) != len(list_b):
            return False
        return all(abs(a - b) <= error for a, b in zip(list_a, list_b))

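# Expected data file format: one point per line, coordinates separated by
# whitespace, e.g.
#   num_x num_y num_z ...
#   ...
# Every token should be a number; anything else is treated as missing data.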
class point_data_reader:
    """Read whitespace-separated numeric point data from a text file."""

    # Class-level default; __init__ sets the path on each instance.
    file_name = ''

    def __init__(self, file_name):
        # Store on the instance (assigning to the class made all readers share
        # the most recently constructed path).
        self.file_name = file_name

    @staticmethod
    def _parse_number(token, num_lost):
        """Parse token as an int or float; return num_lost if it is not a finite number.

        Deliberately avoids eval(): the data file is external input, and eval
        would execute arbitrary code contained in it.
        """
        try:
            return int(token)
        except ValueError:
            pass
        try:
            value = float(token)
        except ValueError:
            return num_lost
        # 'nan' / 'inf' parse as floats but are not usable data values.
        return value if math.isfinite(value) else num_lost

    def get_data_list(self, num_lost):
        """Return the file's points as a list of rows of numbers.

        Each non-blank line becomes one row. Tokens that are not finite
        numbers are replaced with num_lost (the fill value for missing data).
        Blank lines are skipped; they would otherwise yield empty rows that
        break downstream shape handling.
        """
        db = list()
        # Read-only mode; the with-block closes the file even on error.
        with open(self.file_name, 'r') as file_:
            for line in file_:
                tokens = line.split()
                if not tokens:
                    continue
                db.append([self._parse_number(token, num_lost) for token in tokens])
        return db

if __name__ == '__main__':
    # Value substituted for tokens in the data file that are not numbers.
    fill_value = 1.0
    # Number of Gaussian components in the mixture.
    component_count = 2
    reader = point_data_reader('text.dat')
    em = EM_gauss(component_count, reader.get_data_list(num_lost=fill_value))
    em.rounds()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值