机器学习基础:理解极大似然估计
本文将帮助你直观的理解为什么极大似然估计(Maximum Likelihood Estimation)可以用作模型参数的估计。
1. 什么是模型参数(Parameter)
每种模型内都存在着一系列参数,当使用不同数据时,模型参数会相应地改变.
最常见的模型之一就是线性模型: y = bx + c, 其中b和c就是该模型的参数。
比如我们用体重(kg)来估计身高(cm)时,体重每增加1公斤,身高就增加b厘米,c则是该直线的截距。
2.什么是似然(Likelihood):
在理解似然前,我们最好将其与概率作比较,概率与似然的定义分别如下:
- 概率:概率描述了当模型的参数确定时,出现特定事件的可能性。
- 似然:似然描述了当特定观测出现时,对该观测而言模型的好坏程度,当似然函数值越大时,模型参数越优。
请注意,当概率分布为连续概率分布时,由于任意单个事件的概率都为0,此时我们可以用概率密度来代替概率,这一点将在后文中用极大似然估计确定正态分布函数的参数这一例子中体现。
以投掷一枚硬币5次为例:
- 概率:假设我们投掷一枚质地均匀(即出现正面和反面的概率都是0.5)的硬币5次,连续五次都出现正面的概率是0.5的五次方,即0.03125。
- 似然(用L表示):假设一枚硬币的质地未知,有可能正面重反面轻,也有可能正面轻反面重,从而导致投掷后出现正反的概率不同。为了确定硬币的质地,我们可以构建一个模型f(P),参数P表示硬币的质地或出现正面的概率。
- 假设我们投掷了5次硬币,且每次结果都是正面。
- 假设P = 0.3时,连续出现五次正面或模型的似然函数值L(P=0.3 | 连续出现5次正面) = 0.3 ^ 5 = 0.0024
- 假设P = 0.5时,连续出现五次正面或模型的似然函数值L(P=0.5 | 连续出现5次正面) = 0.5 ^ 5 = 0.03125
- 假设P = 0.8时,连续出现五次正面或模型的似然函数值L(P=0.8 | 连续出现5次正面) = 0.8 ^ 5 = 0.32768
- 通过观测结果而言,在上述的三个参数之中,P=0.8时的似然函数值(L)最大,因此我们最有理由相信0.8是该模型的最优参数;用通俗的语言来解释,因为连续五次投掷硬币都结果都是正面,所以我们更有理由相信该硬币的质地为正面重反面轻,所以P=0.8这一参数最为合理。
3.什么是对数似然函数(Log Likelihood)
定义上对数似然函数就是似然函数的对数,即:
LL = log(L)
其中L表示似然函数,LL表示对数似然函数,log是以e为底的对数。
那么我们为什么需要对数似然函数(Log Likelihood)这一概念呢?其中一个理由如下:
通过上述投硬币的例子可以知道,我们通过连乘一系列小于1的数字来计算似然函数,当观测的样本量很大时,似然度函数值会趋近于0,受限于计算机储存数据的精度,该数值很可能无法被正确计算。因此我们需要引入对数似然函数这一概念,并利用对数运算中log(a * b)等价于log(a) + log(b)的这一性质来比较似然度。又因为对数函数是单调递增的,因而比较对数似然函数和比较似然函数本身是等价的。
假设我们观测了10000次硬币的投掷结果,且每一次都是正面:
- 当我们用Python计算0.5的一万次方(0.5 ** 10000)和0.8的一万次方(0.8 ** 10000)时,虽然我们知道后者的值更大,但Python对两者的计算结果都是0.0,从而我们无法比较参数P=0.5和P=0.8时模型的好坏。
- 但当我们计算对数似然函数时(用LL表示)时:
- P=0.5时,LL(P=0.5 | 连续出现10000次正面)= log(0.5 ^ 10000) = 10000 * log(0.5) = -6931
- P=0.8时,LL(P=0.8 | 连续出现10000次正面)= log(0.8 ^ 10000) = 10000 * log(0.8) = -2231
- 在上述两个参数中,因为 -2231 > -6931, 所以当参数P=0.8时模型最优。通过对数似然函数,我们回避了因计算机储存精度不足而导致的错误。
4. 什么是极大似然估计(Maximum Likelihood Estimation)
极大似然估计就是用似然函数来决定模型最优参数的算法。该算法的核心思想就是当特定观测出现后,通过最最优化算法找到一系列参数使模型的似然函数值最大,此时模型的参数就是最优参数。
以上解释可能仍不太直观,接下来我们讲通过一个例子来作具体说明:
5. 用极大似然估计确定正态分布函数的参数
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
from functools import reduce
from math import log
np.random.seed(seed=39)
# define function to draw normal distribution
def plot_normal(model,name,color):
m