EM算法與GMM
Hongliang He 2014年4月 hehongliang168168@126.com
注:本文主要參考Andrew Ng的Lecture notes 8,並結合自己的理解和擴展完成。
GMM簡介
GMM(Gaussian mixture model) 混合高斯模型在機器學習、計算機視覺等領域有着廣泛的應用。其典型的應用有概率密度估計、背景建模、聚類等。
圖1 GMM用於聚類 圖2 GMM用於概率密度估計 圖3 GMM用於背景建模
我們以GMM聚類為例子進行討論。如圖1所示,假設我們有m個點,其坐標數據為{,…}。假設m個數據分別屬於k個類別,且不知道每個點屬於哪一個類。倘若假設每個類的分布函數都是高斯分布,那我們該如何求得每個點所屬的類別?以及每個類別的概率分布函數(概率密度估計)?我們先嘗試最大似然估計。
上式中是當前m個數據出現的概率,我們要將它最大化;是出現的概率;是指第z個類;u和分別指第z個類的均值和方差;為其他的參數。為計算方便,對上式兩邊取對數,得到似然函數。
上說道,GMM的表達式為k個高斯分布的疊加,所以有
為類出現的先驗概率。令j=,所以此時的似然函數可以寫為
上式中x和z為自變量;為需要估計的參數。為高斯分布,我們可以寫出解析式,但是的形式是未知的。所以如果我們不能直接對求偏導取極值。考慮到z是不能直接觀測到的,我們稱為隱藏變量(latent variable)。為了求解
我們引入EM算法(Expectation-Maximization)。我們從Jensen不等式開始討論EM算法。
Jensen不等式
若實函數存在二階導且有,則為凸函數(convex function)。的值域為,則對於
有以下不等式成立:
此不等式的幾何解釋如下
需要說明的是,若則不等式的方向取反。對上式進行推廣,便可得到Jensen不等式(Jensen's Inequality)。倘若有為凸函數,且
則有
此結果可由數學歸納法得到,在這里不做詳細的描述。值得注意的是,如果Jensen不等式中的,而且把看做概率密度,則有
上式成立的依據是,,為概率密度時,f(E(x))=且。在后續的EM算法推導中,會連續多次應用到Jensen不等式的性質。
EM算法
現在重新考慮之前的似然函數
直接對上式進行最大化求解會比較困難,所以我們考慮進行一定的變通。假設是某種概率密度函數,有且。現在對的表達式進行一定得處理,先乘以一個再除以一個,有
我們把看做是的函數; 為概率密度,則有
考慮到log函數為凹函數,利用Jensen不等式有
此時我們找到了的一個下界。而且這個下界的選取隨着的不同而不同。即我們得到了一組下界。用下圖來簡單描述
圖3 選擇不同的得到不同的下界
我們的目的是最大化,如果我們不斷的取的最優下界,再優化最優下界,等到算法收斂就得到了局部最大值。所以我們先取得的最優下界。上式在等號成立時取得最優下界。根據Jensen不等式的性質,取得等號時的條件有
c是不依賴於的常數。此時如果選取就可使得上式成立。又考慮到=1,所以我們可以取
所以取后驗概率的時候是最優下界。如果此時在下界的基礎上優化參數使其最大化,則可進一步抬高。如此循環往復的進行:取最優化下界;優化下界,便是EM算法的做法。接下來正式給出EM算法的步驟:
算法開始
E-step:取似然函數的最優下界,對於每個訓練樣本計算。
M-step:優化下界,即求取。
判斷是否成立,若成立則算法結束。是設定的算法收斂時的增量。
這就是一個不斷取最優下界,抬高下界的過程。用下圖簡單的表示一個迭代過程:
圖4 EM算法的幾何解釋
我們可以這樣解釋:E-step就是取的最優下界,此處是。在M-step,我們優化下界,通過調整使得取得局部最優值。由於Jensen不等式始終成立,始終大於等於下界,所以的值從變為實現上升。那么這樣的迭代是否是收斂的呢?
假設在t時刻的參數為此時的似然函數值為。接下來進行EM算法迭代,在E-step
第二步利用了Jensen不等式。在M-step
所以有
上式第二步中再次用到Jensen不等式。所以似然函數會一直單調遞增,直到到達局部最優值。利用圖4來解釋的話我們可以這樣看:在E-step我們選取了最優下界,此時=;在M-step我們優化得到;最后Jensen不等式一直都成立,所以有=,即。
GMM的訓練
對於GMM,其表達式為
是每個gauss分量的權重。在E-step有
對於M-step
其中需要優化的參數為均值分別對其求偏導。
令
解出
這便是第l個高斯分量均值在M-step的更新公式。
對於協方差矩陣
考慮到
且有
所以有
等價於
為對稱陣,,所以有
解出協方差矩陣的更新公式為
以上便是協方差矩陣
對於每個gauss分量的權重(或者說是先驗概率),考慮到有等式約束
應用Lagrange乘子法
所以有
考慮到
聯立方程可解得
這便是的更新公式。
總結啟發
-
EM 算法適用於似然函數中具有隱藏變量的估計問題。
-
創造下界的想法非常精妙,應該有廣泛的應用前景。
-
Jensen 不等式在不等式證明方面有着廣泛應用。
GMM的簡單應用
接下來簡單討論GMM在圖像分割中的應用。以圖像中每個像素的顏色信息作為特征進行聚類進而達到圖像分割的目的。我們同時拿k-means算法作為對比。
-
K-means 和 GMM 用於圖像分割由於只考慮了像素的顏色信息,沒有考慮空間信息導致其對於復雜背景的效果很差。對於簡單背景和前景的顏色分布都比較柔和的情況有較好的效果。
-
K-means 初始值的選擇非常重要。不好的初始值經常會造成較差的聚類效果。
-
應用 GMM 時,先將 3 通道彩色圖像轉換為了灰度圖。原因是原始的 3 個通道數據存在很強的相關性,導致協方差矩陣不可逆。
-
聚類 ( 分割 ) 時需要手動確定類別的數量。類的數量對於聚類效果也有很大的影響。
Matlab實現
根據以上推導,可以很容易實現EM算法估計GMM參數。現以1維數據2個高斯混合概率密度估計作為實例,詳細代碼如下所示。
% fitting_a_gmm.m
% EM算法簡單實現
% Hongliang He 2014/03
clear
close all
clc
% generate data
len1 = 1000;
len2 = fix(len1 * 1.5);
data = [normrnd(0, 1, [1 len1]) normrnd(4, 2, [1 len2])] + 0.1*rand([1 len1+len2]);
data_len = length(data);
% use EM algroithm to estimate the parameters
ite_cnt = 100000; % maximum iterations
max_err = 1e-5; % 迭代停止條件
% soft boundary EM algorithm
z0 = 0.5; % prior probability
z1 = 1 - z0;
u = mean(data);
u0 = 1.2 * u;
u1 = 0.8 * u;
sigma0 = 1;
sigma1 = 1;
itetation = 0;
while( itetation < ite_cnt )
% init papameters
w0 = zeros(1, data_len); % Qi, postprior
w1 = zeros(1, data_len);
% E-step, update Qi/w to get a tight lower bound
for k1=1:data_len
p0 = z0 * gauss(data(k1), u0, sigma0);
p1 = z1 * gauss(data(k1), u1, sigma1);
p = p0 / (p0 + p1);
if p0 == 0 && p1 == 0
%p = w0(k1);
dist0 = (data(k1)-u0).^2;
dist1 = (data(k1)-u1).^2;
if dist0 > dist1
p = w0(k1) + 0.01;
elseif dist0 == dist1
else
p = w0(k1) - 0.01;
end
end
if p > 1
p = 1;
elseif p < 0
p = 0;
end
w0(k1) = p; % postprior
w1(k1) = 1 - w0(k1);
end
% record the pre-value
old_u0 = u0;
old_u1 = u1;
old_sigma0 = sigma0;
old_sigma1 = sigma1;
% M-step, maximize the lower bound
u0 = sum(w0 .* data) / sum(w0);
u1 = sum(w1 .* data) / sum(w1);
sigma0 = sqrt( sum(w0 .* (data - u0).^2) / sum(w0));
sigma1 = sqrt( sum(w1 .* (data - u1).^2) / sum(w1));
z0 = sum(w0) / data_len;
z1 = sum(w1) / data_len;
% is convergance
if mod(itetation, 10) == 0
sprintf('%d: u0=%f,d0=%f u1=%f,d1=%f\n',itetation, …
u0,sigma0,u1,sigma1)
end
d_u0 = abs(u0 - old_u0);
d_u1 = abs(u1 - old_u1);
d_sigma0 = abs(sigma0 - old_sigma0);
d_sigma1 = abs(sigma1 - old_sigma1);
% 迭代停止判斷
if d_u0 < max_err && d_u1 < max_err && …
d_sigma0 < max_err && d_sigma1 < max_err
clc
sprintf('ite = %d, final value is', itetation)
sprintf('u0=%f,d0=%f u1=%f,d1=%f\n', u0,sigma0,u1,sigma1)
break;
end
itetation = itetation + 1;
end
% compare
my_hist(data, 20);
hold on;
mi = min(data);
mx = max(data);
t = linspace(mi, mx, 100);
y = z0*gauss(t, u0, sigma0) + z1*gauss(t, u1, sigma1);
plot(t, y, 'r', 'linewidth', 5);
% gauss.m
% 1維高斯函數
% Hongliang He 2014/03
function y = gauss(x, u, sigma)
y = exp( -0.5*(x-u).^2/sigma.^2 ) ./ (sqrt(2*pi)*sigma);
end
% my_hist.m
% 用直方圖估計概率密度
% Hongliang He 2013/03
function my_hist(data, cnt)
dat_len = length(data);
if dat_len < cnt*5
error('There are not enough data!\n')
end
mi = min(data);
ma = max(data);
if ma <= mi
error('sorry, there is only one type of data\n')
end
dt = (ma - mi) / cnt;
t = linspace(mi, ma, cnt);
for k1=1:cnt-1
y(k1) = sum( data >= t(k1) & data < t(k1+1) );
end
y = y ./ dat_len / dt;
t = t + 0.5*dt;
bar(t(1:cnt-1), y);
%stem(t(1:cnt-1), y)
end
最終運行結果:
EM算法最終結果