机器学习EM算法-硬币问题-详细可视化代码

致读者: 博主是一名数据科学与大数据专业大二的学生,真正的一个互联网萌新,写博客一方面是为了记录自己的学习过程中遇到的问题和思考,一方面是希望能够帮助到很多和自己一样处于困惑的读者。
由于水平有限,博客中难免会有一些错误,有纰漏之处恳请各位大佬不吝赐教!之后会写大数据专业的文章哦。
GitHub链接https://github.com/wfy-belief
尽管现在我的水平可能还不太及格,但我会尽我自己所能,做到最好☺。——天地有正气,杂然赋流形。下则为河岳,上则为日星。


算法简介

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

算法步骤

在这里插入图片描述

代码效果展示

第一次

在这里插入图片描述

第二次

在这里插入图片描述

第n次

代码实现

import math
import matplotlib.pyplot as plt


class Coin_Distribution():
    """
    硬币的分布
    """

    def __init__(self, head, tail):
        """
        硬币的正面和反面的个数
        """
        self.head = head
        self.tail = tail


class Solution():
    def __init__(self, theta_A, theta_B):
        """
        初始化条件
        """
        self.distribution1 = Coin_Distribution(5, 5)
        self.distribution2 = Coin_Distribution(9, 1)
        self.distribution3 = Coin_Distribution(8, 2)
        self.distribution4 = Coin_Distribution(4, 6)
        self.distribution5 = Coin_Distribution(7, 3)
        self.distribution = {
            '1': 0,
            '2': 0,
            '3': 0,
            '4': 0,
            '5': 0,
        }
        self.A = theta_A
        self.B = theta_B

    def get_distribution(self, head, tail, key):
        """
        求每个分布的概率,A投掷十次出现相应分布的概率
        """
        pa = math.pow(self.A, head) * math.pow(1 - self.A, tail)
        pb = math.pow(self.B, head) * math.pow(1 - self.B, tail)
        self.distribution[key] = pa / (pa + pb)
        # print(pa / (pa + pb))

    def get_five_distribution(self):
        self.get_distribution(self.distribution1.head,
                              self.distribution1.tail, '1')
        self.get_distribution(self.distribution2.head,
                              self.distribution2.tail, '2')
        self.get_distribution(self.distribution3.head,
                              self.distribution3.tail, '3')
        self.get_distribution(self.distribution4.head,
                              self.distribution4.tail, '4')
        self.get_distribution(self.distribution5.head,
                              self.distribution5.tail, '5')

    def M(self):
        sum_head_A = 0
        sum_tail_A = 0
        sum_head_B = 0
        sum_tail_B = 0
        # get A head
        # print(self.distribution['1'], self.distribution1.head)
        sum_head_A += self.distribution['1'] * self.distribution1.head
        sum_head_A += self.distribution['2'] * self.distribution2.head
        sum_head_A += self.distribution['3'] * self.distribution3.head
        sum_head_A += self.distribution['4'] * self.distribution4.head
        sum_head_A += self.distribution['5'] * self.distribution5.head
        # get A tail
        sum_tail_A += self.distribution['1'] * self.distribution1.tail
        sum_tail_A += self.distribution['2'] * self.distribution2.tail
        sum_tail_A += self.distribution['3'] * self.distribution3.tail
        sum_tail_A += self.distribution['4'] * self.distribution4.tail
        sum_tail_A += self.distribution['5'] * self.distribution5.tail

        # get B heBd
        sum_head_B += (1 - self.distribution['1']) * self.distribution1.head
        sum_head_B += (1 - self.distribution['2']) * self.distribution2.head
        sum_head_B += (1 - self.distribution['3']) * self.distribution3.head
        sum_head_B += (1 - self.distribution['4']) * self.distribution4.head
        sum_head_B += (1 - self.distribution['5']) * self.distribution5.head
        # get B tBil
        sum_tail_B += (1 - self.distribution['1']) * self.distribution1.tail
        sum_tail_B += (1 - self.distribution['2']) * self.distribution2.tail
        sum_tail_B += (1 - self.distribution['3']) * self.distribution3.tail
        sum_tail_B += (1 - self.distribution['4']) * self.distribution4.tail
        sum_tail_B += (1 - self.distribution['5']) * self.distribution5.tail

        self.A = sum_head_A / (sum_head_A + sum_tail_A)
        self.B = sum_head_B / (sum_head_B + sum_tail_B)
        print(self.A, self.B)

    def init_image(self):
        fig = plt.figure('EM', figsize=(4.5, 7))
        plt.title('EM')
        plt.xticks([])
        plt.yticks([])

    def drow_image(self, num):
        plt.title('This is the %d time EM' % num)
        plt.plot([1, 1, 5, 5, 1, 1, 5, 5, 1, 1, 5, 5, 1, 1, 5, 5, 1, 1, 3, 3],
                 [1, 8, 8, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 1])
        plt.text(2, 1.5, r'$%s$' % str('COIN-A'), ha='center', c='r')
        plt.text(4, 1.5, r'$%s$' % str('COIN-B'), ha='center', c='g')
        # coin A
        # print(self.distribution['1'], self.distribution1.head)

        plt.text(1.5, 2.5, r'$%.1fH,%.1fT$' % (self.distribution['1'] * self.distribution1.head,
                                               self.distribution['1'] * self.distribution1.tail), c='r')
        plt.text(1.5, 3.5, r'$%.1fH,%.1fT$' % (self.distribution['2'] * self.distribution2.head,
                                               self.distribution['2'] * self.distribution2.tail), c='r')
        plt.text(1.5, 4.5, r'$%.1fH,%.1fT$' % (self.distribution['3'] * self.distribution3.head,
                                               self.distribution['3'] * self.distribution3.tail), c='r')
        plt.text(1.5, 5.5, r'$%.1fH,%.1fT$' % (self.distribution['4'] * self.distribution4.head,
                                               self.distribution['4'] * self.distribution4.tail), c='r')
        plt.text(1.5, 6.5, r'$%.1fH,%.1fT$' % (self.distribution['5'] * self.distribution5.head,
                                               self.distribution['5'] * self.distribution5.tail), c='r')
        # coin B
        plt.text(3.5, 2.5, r'$%.1fH,%.1fT$' % ((1 - self.distribution['1']) * self.distribution1.head,
                                               (1 - self.distribution['1']) * self.distribution1.tail), c='g')
        plt.text(3.5, 3.5, r'$%.1fH,%.1fT$' % ((1 - self.distribution['2']) * self.distribution2.head,
                                               (1 - self.distribution['2']) * self.distribution2.tail), c='g')
        plt.text(3.5, 4.5, r'$%.1fH,%.1fT$' % ((1 - self.distribution['3']) * self.distribution3.head,
                                               (1 - self.distribution['3']) * self.distribution3.tail), c='g')
        plt.text(3.5, 5.5, r'$%.1fH,%.1fT$' % ((1 - self.distribution['4']) * self.distribution4.head,
                                               (1 - self.distribution['4']) * self.distribution4.tail), c='g')
        plt.text(3.5, 6.5, r'$%.1fH,%.1fT$' % ((1 - self.distribution['5']) * self.distribution5.head,
                                               (1 - self.distribution['5']) * self.distribution5.tail), c='g')
        # coni A and B
        plt.text(1.2, 7.5, r'CoinA is $p_{A}(%.2f)$' % self.A, c='k')
        plt.text(3.2, 7.5, r'CoinB is $p_{B}(%.2f)$' % self.B, c='k')
        # describe
        plt.text(1.5, 8.5, r'Please click on the X to continue...', c='b')

    def set_ax(self):
        ax = plt.gca()
        ax.spines['right'].set_color('none')
        ax.spines['bottom'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['left'].set_color('none')
        ax.xaxis.set_ticks_position('top')
        ax.invert_yaxis()

    def show_image(self, num):
        self.init_image()
        self.set_ax()
        self.drow_image(num)
        plt.show()


if __name__ == "__main__":
    S = Solution(0.6, 0.5)
    for i in range(5):
        S.get_five_distribution()
        S.M()
        S.show_image(i + 1)

评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZZULI_星.夜

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值