6. EM算法-高斯混合模型GMM+Lasso詳細代碼實現


1. 前言

我們之前有介紹過4. EM算法-高斯混合模型GMM詳細代碼實現,在那片博文里面把GMM說涉及到的過程,可能會遇到的問題,基本講了。今天我們升級下,主要一起解析下EM算法中GMM(搞事混合模型)帶懲罰項的詳細代碼實現。

2. 原理

由於我們的極大似然公式加上了懲罰項,所以整個推算的過程在幾個地方需要修改下。

在帶penality的GMM中,我們假設協方差是一個對角矩陣,這樣的話,我們計算高斯密度函數的時候,只需要把樣本各個維度與對應的\(\mu_k\)\(\sigma_k\)計算一維高斯分布,再相加即可。不需要通過多維高斯進行計算,也不需要協方差矩陣是半正定的要求。

我們給上面的(1)式加入一個懲罰項,

\[\lambda\sum_{k=1}^K\sum_{j=1}^P\frac{|\mu_k-\bar{x}_j|}{s_j} \]

其中的\(P\)是樣本的維度。\(\bar{x}_j\)表示每個維度的平均值,\(s_j\)表示每個維度的標准差。這個penality是一個L1范式,對\(\mu_k\)進行約束。

加入penality后(1)變為

\[L(\theta,\theta^{(j)})=\sum_{k=1}^Kn_k[log\pi_k-\frac{1}{2}(log(\boldsymbol{\Sigma_k})+\frac{{(x_i-\boldsymbol{\mu}_k})^2}{\boldsymbol{\Sigma}_k})] - \lambda\sum_{k=1}^K\sum_{j=1}^P\frac{|\mu_k-\bar{x}_j|}{s_j} \]

這里需要注意的一點是,因為penality有一個絕對值,所以在對\(\mu_k\)求導的時候,需要分情況。於是(2)變成了

\[\mu_k=\frac{1}{n_k}\sum_{i=1}^N\gamma_{ik}x_i \]

\[\mu_k= \left \{\begin{array}{cc} \frac{1}{n_k}(\sum_{i=1}^N\gamma_{ik}x_i - \frac{\lambda\sigma^2}{s_j}), & \mu_k >= \bar{x}_j\\ \frac{1}{n_k}(\sum_{i=1}^N\gamma_{ik}x_i + \frac{\lambda\sigma^2}{s_j}), & \mu_k < \bar{x}_j \end{array}\right. \]

3. 算法實現

  • 和不帶懲罰項的GMM不同的是,我們GMM+LASSO的計算高斯密度函數有所變化。
#計算高斯密度概率函數,樣本的高斯概率密度函數,其實就是每個一維mu,sigma的高斯的和
def log_prob(self, X, mu, sigma):
    N, D = X.shape
    logRes = np.zeros(N)
    for i in range(N):
        a = norm.logpdf(X[i,:], loc=mu, scale=sigma)
        logRes[i] = np.sum(a)
    return logRes
  • 在m-step中計算\(\mu_{k+1}\)的公式需要變化,先通過比較\(\mu_{kj}\)\(means_{kj}\)的大小,來確定絕對值shift的符號。
def m_step(self, step):
    gammaNorm = np.array(np.sum(self.gamma, axis=0)).reshape(self.K, 1)
    self.alpha = gammaNorm / np.sum(gammaNorm)
    for k in range(self.K):
        Nk = gammaNorm[k]
        if Nk == 0:
            continue
        for j in range(self.D):
            if step >= self.beginPenaltyTime:
                # 算出penality的偏移量shift,通過當前維度的mu和樣本均值比較,確定shift的符號,相當於把lasso的絕對值拆開了
                shift = np.square(self.sigma[k, j]) * self.penalty / (self.std[j] * Nk)
                if self.mu[k, j] >= self.means[j]:
                    shift = shift
                else:
                    shift = -shift
            else:
                shift = 0
            self.mu[k, j] = np.dot(self.gamma[:, k].T, self.X[:, j]) / Nk - shift
            self.sigma[k, j] = np.sqrt(np.sum(np.multiply(self.gamma[:, k], np.square(self.X[:, j] - self.mu[k, j]))) / Nk)
  • 最后需要修改loglikelihood的計算公式
def GMM_EM(self):
    self.init_paras()
    for i in range(self.times):
        #m step
        self.m_step(i)
        # e step
        logGammaNorm, self.gamma= self.e_step(self.X)
        #loglikelihood
        loglike = self.logLikelihood(logGammaNorm)
        #penalty
        pen = 0
        if i >= self.beginPenaltyTime:
            for j in range(self.D):
                pen += self.penalty * np.sum(abs(self.mu[:,j] - self.means[j])) / self.std[j]

        # print("step = %s, alpha = %s, loglike = %s"%(i, [round(p[0], 5) for p in self.alpha.tolist()], round(loglike - pen, 5)))
        # if abs(self.loglike - loglike) < self.tol:
        #     break
        # else:

        self.loglike = loglike - pen

4. GMM算法實現結果

用我實現的GMM+LASSO算法,對多個penality進行計算,選出loglikelihood最大的k和penality,與sklearn的結果比較。

fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
myself GMM alpha = [0.52838, 0.47162], loglikelihood = -693.34677, bestP = 0
sklearn GMM alpha = [0.53372, 0.46628], loglikelihood = -176.73112
succ = 299/300
succ = 0.9966666666666667
[0 1 0 0 1 1 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
[0 1 0 0 1 0 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
fileName = amix1-tst.dat, loglike = -2389.1852339407087
fileName = amix1-val.dat, loglike = -358.1157431278091
fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.56, 0.44], loglike = 53804.54265
fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.82, 0.18], loglike = 24902.5522
fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.82, 0.18], loglike = 23902.65183
fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.56, 0.44], loglike = 52929.96459
fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
myself GMM alpha = [0.56, 0.44], loglikelihood = 53804.54265, bestP = 0
sklearn GMM alpha = [0.56217, 0.43783], loglikelihood = 11738677.90164
succ = 200/200
succ = 1.0
[0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
[0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
fileName = amix2-tst.dat, loglike = 51502.878096147084
fileName = amix2-val.dat, loglike = 6071.217012747491
fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.575, 0.425], loglike = -24790.19895
fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.525, 0.475], loglike = -24440.82743
fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.55, 0.45], loglike = -25582.27485
fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.6, 0.4], loglike = -26137.97508
fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26686.02411
fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26941.68964
myself GMM alpha = [0.525, 0.475], loglikelihood = -24440.82743, bestP = 0
sklearn GMM alpha = [0.5119, 0.4881], loglikelihood = 13627728.10766
succ = 29/40
succ = 0.725
[0 1 0 1 0 1 0 1 0 1 0 0 0 0 0 1 1 1 0 1]
[0 1 0 1 1 1 0 0 1 1 0 1 0 1 0 0 1 1 0 0]
fileName = golub-tst.dat, loglike = -12949.606698037718
fileName = golub-val.dat, loglike = -11131.35137056415

5. 總結

通過一番改造,實現了GMM+LASSO的代碼,如果讀者有什么好的改進方法,或者我有什么錯誤的地方,希望多多指教。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM