EM算法

EM(Expectaion Maximization)算法

更多:贝叶斯

9.1 简单介绍

EM(Expectaion Maximization)算法(又称为期望最大化方法)是一种迭代算法,Dempster等人在1977年总结提出来的。简单来说EM算法就是一种含有隐变量的概率模型参数的极大似然估计。EM算法的每次迭代由两步组成:第一是求期望,第二是求极大。EM算法在机器学习中有极为广泛的应用。如常被用来学习高斯混合模型(Gaussian mixture model, 简称GMM)的参数。

那么什么是含有隐变量的概率模型?这里举一个常用的三硬币例子,假设我们有三枚硬币:A、B和C,他们的质地都是不均匀的,假设他们正面朝上的概率分别是:a、b和c。现在弄一个抛硬币的规则,先抛A硬币,如果A正面朝上,那么就抛B硬币,否则就抛C硬币。最后记下最终结果,正面朝上记为1,否则记为0。现在进行10次该实验,假如得到的结果如下: 1,0,0,1,1,1,0,1,0,0。这个时候我们其实只得到了最终的结果,并不知道是B还是C硬币的结果,因为不知道每次A硬币的结果。这个时候A硬币的抛掷就可以认为是一个隐含变量。但是问题是如何根据这个结果来估计这三个参数呢?

9.2理论推导
9.2.1 算法思想

在解决例子问题之前,我们先进行一些所谓枯燥的数学化定义,这样或许能帮助理解和记忆。

简单阐述就是:其实这里有两类变量,一类是隐变量,一类是待求的参数变量。那么普通的思路该怎么求这个参数变量呢?由上述阐述可以知道,如果我们事先知道了隐变量就能利用极大似然来估计参数,如果我们知道了参数,那么我们可以计算出隐变量集的期望。这里就形成了一个制约,只要我们给出隐变量的初始值就能通过迭代达到两类变量之间的平衡,也就是收敛。类似于我们在生活中的称重,如果要将一类物品分为两部分(比如糖果),在没有称的情况下,往往我们在左右手进行掂量(这就有点像两类变量),如果左手上重了就分点到右手上,否则,从右手上扒拉点分到左手,直到感觉两只手上重量差不多。

所以这里就落下了两个最主要的问题

9.2.2 算法推导

重复上述两个步骤,直到收敛到局部最优解。详细的有关收敛问题,可以参考李航老师的《统计学习方法》。

这里就通过简单的方法介绍一下EM算法的核心思路,然后主要是通过以下几个例子来感受一下EM算法。

9.3 三枚硬币

这里我们首先借用一下李航老师在《统计学习方法》中的三枚硬币模型的例子进行阐述。

我们拆开来看,其实两项都是二分布(简单的抛掷硬币过程, 假设A为正面朝上,然后进行一次抛硬币,A反面朝上同样)。所以我们可以继续写。

我们可以简单的来核对一下这个概率模型写得对不对。我们画出一次抛掷硬币的整个过程,并计算出相应的概率。然后带入到上面的公式中就可以知道模型构建是否正确了。

重复多次的原理也是如此,只不过因为进行的多次独立实验,所以计算概率直接用连乘累积。多次独立实验概率模型如下。

但是这个似然函数(联乘)求导将非常复杂,所以在极大似然估计中一般都转换为对数似然函数。但是依然非常复杂,所以用EM算法迭代的思想来求解。

这里根据Step4我们得到了隐变量的期望(也就是得到了隐变量),由此可以直接对上述的(4)式进行求导,得到相应的参数。

Step6: 进行数值计算并迭代

如何判断停止呢?

9.4 一个简单例子

下面我们将编程实现一个最简单的正态分布参数估计的例子,感受一下EM算法。

case1: 
如下图所示是两组一维正态分布数据。这两组是不带隐含变量的正态分布,我们可以很明显的看出这两组正态分布数据,红色的点是一类,蓝色的是另外一类。也可以大概估计出这个例子中两组正态分布的均值,比如红色类别大概是3左右。这就是极大似然估计所处理的一个场景:不带隐变量的参数估计方法

case2:
如下图所示也告诉是两组一维正态分布的数据。但是是带有隐变量的,所以我们完全看不出两个类别,也不太好利用极大似然估计的方法来找到两组参数。这个时候就是EM算法的角斗场,利用反复迭代的思路。为了方便理解,我们完全可以理解为用一个滤波器一样的东西在给出的数据上滑动,看哪个滤波器(一组相应的正态分布的参数)能产生最小的误差。

下面我们将按照上述三枚硬币的例子写一小段代码来感受EM算法迭代求解的过程。

Step0:随机数据生成 
首先,显然先生成这样两组正态分布的数据。

import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
# sort all the data for later use
both_colours = np.sort(np.concatenate((red, blue))) # for later use...

我们可以输出结果看看,方便之后与估计的参数进行对比。

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

初始化的参数绘制的图像如下所示,看见这张图是不是有点似曾相识,在贝叶斯决策中我们就画过这样的图,判断是否属于某个类别的时候,分界面是一个点,比如下图中,红色和蓝色正态分布图的交叉点就是分界点,小于这个点就是属于红色,否则是蓝色。但是EM算法更加强,不仅能找出这个分界面,而且能估计出参数。

Step2: 计算似然函数

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Step3:估计参数

def estimate_mean(data, weight):
    return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

**Step4: 迭代20次后结果 **

下图是迭代的过程绘制在图像上,可以看出拟合程度。

最终的结果图如下所示。分界点大概在4.2左右,参考最开始给出的那个红蓝分解点可以看出,估计还是比较准确的。

最下面是估计结果参数对比实际的参数的表格。

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032
9.5 参考文献

[1] 李航,《统计学习方法》 
[2] 周志华, 《机器学习》 
[3] 机器学习算法系列之一】EM算法实例分析 
[4] 从最大似然到EM算法浅解 
http://blog.csdn.net/zouxy09/article/details/8537620
[5] CS229 Lecture notes by Andrew Ng 
[6] What is an intuitive explanation of the Expectation Maximization technique? 

发表评论

电子邮件地址不会被公开。 必填项已用*标注