mean shift聚類算法的MATLAB程序


mean shift聚類算法的MATLAB程序

凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/

1. mean shift 簡介

    mean shift, 寫的更符合國人的習慣,應該是mean of shift,也就是平均偏移量,或者偏移均值向量。在明確了含義之后,就可以開始如下的具體講解了。

1). 基本形式

[公式]

其中 [公式][公式] 個樣本點, [公式][公式] 為以 [公式] 為中心的半徑為 [公式] 的高維球體,表示有效區域,其中包含 [公式] 個樣本點。其變形如下:

[公式]

    由此可以可知, [公式] 作為x的偏移均值向量,可用來對 [公式] 進行更新,但這種更新有什么意義呢?通過簡單的二維樣本模擬,可以發現其傾向於向有效區域中樣本密度高(即概率密度大)的地方移動。

2). 改進形式

    基本形式中隱含了在有效區域中對所有的樣本點一視同仁的假設,但這通常是不成立,最常見的就是隨着距離的增加,作用就越小,因此,就有了如下的改進形式:

[公式]

其中 [公式] 為核函數, [公式] 表示帶寬(嚴格來講因為帶寬矩陣,為對角矩陣,但通常對角元素取相等,故可表示為標量), [公式] 為樣本權重。

    由此可對基本形式進行更為合理的表示,采用均勻核函數,從而達到統一表示:

[公式]

2. mean shift 解釋

1). 數學推導

    概率密度估計中,常用的方法有直方圖估計、K近鄰估計、核函數估計,其中核函數估計的表示如下:

[公式]

 

其中 [公式] 同樣表示核函數。對概率密度函數 [公式] 求導如下:

[公式]

    令 [公式] ,其亦是核函數,進一步分解,有如下表示:

[公式]

    可以看出,其中第二項也是一種概率密度的核函數估計,將其表示為 [公式] ,第三項則為上文中的mean shift的改進形式,因此,可以改寫為:

[公式]

    接下來是兩種解釋,首先,求解概率密度局部極大值,令 [公式] ,由於 [公式] ,故有:

[公式]

    這表示mean shift的本質是在求解概率密度局部極大值,即偏移均值向量讓目標點始終向概率密度極大點處移動。但當數據量非常大時,一次遍歷所有樣本點顯然不合適,故常選取目標點 [公式] 附近的一個區域,進行貪心迭代,逐步收斂於概率密度極大值處;另一種更合理的解釋是,通過在核函數 [公式] 中融合進一個均勻核函數來表示選取的有效區域,然后迭代直至收斂。

    再者,從梯度上升的優化角度來講,有如下表示:

[公式]

即偏移均值向量的作用等價於以概率密度為目標的具有自適應步長的梯度上升優化,其在概率密度較小的位置步長較大,當逼近局部極大點時,概率密度較大,因此步長較小,符合梯度優化中步長變化的需要。

    由此,便對mean shift的含義及其合理性進行了解釋,也就不難理解為何mean shift具有強大的效果及適用性了。

2). 泛化拓展

    進一步拓展,雖然一般形式的mean shift是由概率密度的核函數估計推導出來的,其核心是核函數,但由於其具有歸一化表示的性質,因此,理論上可以泛化為如下表示形式:

[公式]

其中 [公式]確定偏移向量 [公式] 的整體權重,可以任意選取,但必然需要具有一定的意義。顯然偏移均值向量會傾向於權重較大的樣本點,因此,從概率密度最大化的角度來看,[公式] 可以是 [公式] 處概率密度的一種表示。

3. mean shift MATLAB程序

testMeanShift.m

clear
clc
profile on

bandwidth = 1;
%% 加載數據
data_load=dlmread('gauss_data.txt');
[~,dim]=size(data_load);
data=data_load(:,1:dim-1);
x=data';
%% 聚類
tic
[clustCent,point2cluster,clustMembsCell] = MeanShiftCluster(x,bandwidth);
% clustCent:聚類中心 D*K, point2cluster:聚類結果 類標簽, 1*N
toc
%% 作圖
numClust = length(clustMembsCell);
figure(2),clf,hold on
cVec = 'bgrcmykbgrcmykbgrcmykbgrcmyk';%, cVec = [cVec cVec];
for k = 1:min(numClust,length(cVec))
    myMembers = clustMembsCell{k};
    myClustCen = clustCent(:,k);
    plot(x(1,myMembers),x(2,myMembers),[cVec(k) '.'])
    plot(myClustCen(1),myClustCen(2),'o','MarkerEdgeColor','k','MarkerFaceColor',cVec(k), 'MarkerSize',10)
end
title(['no shifting, numClust:' int2str(numClust)])

MeanShiftCluster.m

function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag)
%perform MeanShift Clustering of data using a flat kernel
%
% ---INPUT---
% dataPts           - input data, (numDim x numPts)
% bandWidth         - is bandwidth parameter (scalar)
% plotFlag          - display output if 2 or 3 D    (logical)
% ---OUTPUT---
% clustCent         - is locations of cluster centers (numDim x numClust)
% data2cluster      - for every data point which cluster it belongs to (numPts)
% cluster2dataCell  - for every cluster which points are in it (numClust)
% 
% Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Funkunaga and L.D. Hosteler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"


%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end

if nargin < 3
    plotFlag = true;
    plotFlag = false;
end

%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;
maxPos          = max(dataPts,[],2);                    %biggest size in each dimension
minPos          = min(dataPts,[],2);                    %smallest size in each dimension
boundBox        = maxPos-minPos;                        %bounding box size
sizeSpace       = norm(boundBox);                       %indicator of size of data space
stopThresh      = 1e-3*bandWidth;                       %when mean has converged
clustCent       = [];                                   %center of clust
beenVisitedFlag = zeros(1,numPts);                      %track if a points been seen already
numInitPts      = numPts;                               %number of points to posibaly use as initilization points
clusterVotes    = zeros(1,numPts);                      %used to resolve conflicts on cluster membership


while numInitPts

    tempInd         = ceil( (numInitPts-1e-6)*rand);        %pick a random seed point
    stInd           = initPtInds(tempInd);                  %use this point as start of mean
    myMean          = dataPts(:,stInd);                     % intilize mean to this points location
    myMembers       = [];                                   % points that will get added to this cluster                          
    thisClusterVotes    = zeros(1,numPts);                  %used to resolve conflicts on cluster membership

    while 1     %loop untill convergence
        
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2);    %dist squared from mean to all points still active
        inInds      = find(sqDistToAll < bandSq);                     %points within bandWidth
        thisClusterVotes(inInds) = thisClusterVotes(inInds)+1;        %add a vote for all the in points belonging to this cluster
        
        
        myOldMean   = myMean;                                   %save the old mean
        myMean      = mean(dataPts(:,inInds),2);                %compute the new mean
        myMembers   = [myMembers inInds];                       %add any point within bandWidth to the cluster
        beenVisitedFlag(myMembers) = 1;                         %mark that these points have been visited
        
        %*** plot stuff ****
        if plotFlag
            figure(1),clf,hold on
            if numDim == 2
                plot(dataPts(1,:),dataPts(2,:),'.')
                plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
                plot(myMean(1),myMean(2),'go')
                plot(myOldMean(1),myOldMean(2),'rd')
                pause
            end
        end

        %**** if mean doesn't move much stop this cluster ***
        if norm(myMean-myOldMean) < stopThresh
            
            %check for merge posibilities
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));     %distance from posible new clust max to old clust max
                if distToOther < bandWidth/2                    %if its within bandwidth/2 merge new and old
                    mergeWith = cN;
                    break;
                end
            end
            
            
            if mergeWith > 0    % something to merge
                clustCent(:,mergeWith)       = 0.5*(myMean+clustCent(:,mergeWith));              %record the max as the mean of the two merged (I know biased twoards new ones)
                %clustMembsCell{mergeWith}    = unique([clustMembsCell{mergeWith} myMembers]);   %record which points inside 
                clusterVotes(mergeWith,:)    = clusterVotes(mergeWith,:) + thisClusterVotes;     %add these votes to the merged cluster
            else    %its a new cluster
                numClust                    = numClust+1;                    %increment clusters
                clustCent(:,numClust)       = myMean;                        %record the mean  
                %clustMembsCell{numClust}    = myMembers;                    %store my members
                clusterVotes(numClust,:)    = thisClusterVotes;
            end

            break;
        end

    end
    
    
    initPtInds      = find(beenVisitedFlag == 0);           %we can initialize with any of the points not yet visited
    numInitPts      = length(initPtInds);                   %number of active points in set

end

[val,data2cluster] = max(clusterVotes,[],1);                %a point belongs to the cluster with the most votes

%*** If they want the cluster2data cell find it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        myMembers = find(data2cluster == cN);
        cluster2dataCell{cN} = myMembers;
    end
end

數據見:MATLAB中“fitgmdist”的用法及其GMM聚類算法,保存為gauss_data.txt文件,數據最后一列是類標簽。

4. 結果

    注意:聚類結果與核函數中的參數帶寬bandwidth有很大關系,視具體數據而定。

5. 參考文獻

[1] 均值偏移( mean shift )?

[2] Mean Shift Clustering

[3] 簡單易學的機器學習算法——Mean Shift聚類算法


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM