K-SVD字典學習及其實現(Python)


算法思想

算法求解思路為交替迭代的進行稀疏編碼和字典更新兩個步驟. K-SVD在構建字典步驟中,K-SVD不僅僅將原子依次更新,對於原子對應的稀疏矩陣中行向量也依次進行了修正. 不像MOP,K-SVD不需要對矩陣求逆,而是利用SVD數學分析方法得到了一個新的原子和修正的系數向量.

固定系數矩陣X和字典矩陣D,字典的第\(k\)個原子為\(d_k\),同時\(d_k\)對應的稀疏矩陣為\(X\)中的第\(k\)個行向量\(x^k_T\). 假設當前更新進行到原子\(d_k\),樣本矩陣和字典逼近的誤差為:

\[\|Y - DX\|^2_F = \|Y - \sum\limits^K_{j=1}d_jx^j_T\|^2_F = \|(Y - \sum\limits_{j\neq k}d_jx^j_T) - d_kx^j_T\|^2_F = \|E_k -d_kx^k_T\|^2_F \]

在得到當前誤差矩陣\(E_k\)后,需要調整\(d_k\)\(X^k_T\),使其乘積與\(E_k\)的誤差盡可能的小.

如果直接對\(d_k\)\(X^k_T\)進行更新,可能導致\(x^k_T\)不稀疏. 所以可以先把原有向量\(x^k_T\)中零元素去除,保留非零項,構成向量\(x^k_R\),然后從誤差矩陣\(E_k\)中取出相應的列向量,構成矩陣\(E^R_k\). 對\(E^R_k\)進行SVD(Singular Value Decomposition)分解,有\(E^R_k = U\Delta V^T\),由\(U\)的第一列更新\(d_k\),由\(V\)的第一列乘以\(\Delta (1,1)\)所得結果更新\(x^k_R\).

Python實現

import numpy as np
from sklearn import linear_model
import scipy.misc
from matplotlib import pyplot as plt


class KSVD(object):
    def __init__(self, n_components, max_iter=30, tol=1e-6,
                 n_nonzero_coefs=None):
        """
        稀疏模型Y = DX,Y為樣本矩陣,使用KSVD動態更新字典矩陣D和稀疏矩陣X
        :param n_components: 字典所含原子個數(字典的列數)
        :param max_iter: 最大迭代次數
        :param tol: 稀疏表示結果的容差
        :param n_nonzero_coefs: 稀疏度
        """
        self.dictionary = None
        self.sparsecode = None
        self.max_iter = max_iter
        self.tol = tol
        self.n_components = n_components
        self.n_nonzero_coefs = n_nonzero_coefs

    def _initialize(self, y):
        """
        初始化字典矩陣
        """
        u, s, v = np.linalg.svd(y)
        self.dictionary = u[:, :self.n_components]

    def _update_dict(self, y, d, x):
        """
        使用KSVD更新字典的過程
        """
        for i in range(self.n_components):
            index = np.nonzero(x[i, :])[0]
            if len(index) == 0:
                continue

            d[:, i] = 0
            r = (y - np.dot(d, x))[:, index]
            u, s, v = np.linalg.svd(r, full_matrices=False)
            d[:, i] = u[:, 0].T
            x[i, index] = s[0] * v[0, :]
        return d, x

    def fit(self, y):
        """
        KSVD迭代過程
        """
        self._initialize(y)
        for i in range(self.max_iter):
            x = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs)
            e = np.linalg.norm(y - np.dot(self.dictionary, x))
            if e < self.tol:
                break
            self._update_dict(y, self.dictionary, x)

        self.sparsecode = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs)
        return self.dictionary, self.sparsecode


if __name__ == '__main__':
    im_ascent = scipy.misc.ascent().astype(np.float)
    ksvd = KSVD(300)
    dictionary, sparsecode = ksvd.fit(im_ascent)
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(im_ascent)
    plt.subplot(1, 2, 2)
    plt.imshow(dictionary.dot(sparsecode))
    plt.show()

運行結果:
KSVD字典學習結果


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM