MATLAB program for the mean shift clustering algorithm
凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/
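1. Introduction to mean shift
Phrased in a way that better matches Chinese reading habits, mean shift is really "mean of shift": the average offset, or the mean shift vector. With the meaning settled, the detailed explanation follows.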

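1). Basic form
$$M_h(x) = \frac{1}{k} \sum_{x_i \in S_h} (x_i - x)$$
where $x_i$, $i = 1, \dots, n$, are the $n$ sample points and $S_h$ is the high-dimensional ball of radius $h$ centered at $x$. It represents the effective region and contains $k$ sample points. Rearranging gives:
$$M_h(x) = \frac{1}{k} \sum_{x_i \in S_h} x_i - x$$
It follows that $M_h(x)$, the mean shift vector of $x$, can be used to update $x$. But what does this update mean? A simple simulation on two-dimensional samples shows that $x$ tends to move toward the part of the effective region where the samples are dense (that is, where the probability density is high).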
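The two-dimensional simulation mentioned above can be sketched roughly as follows. This is only an illustration under assumed values: the synthetic data, the radius `h`, the starting point, and the stopping tolerance are all made up here and are not taken from the original program.

```matlab
% Minimal sketch of the basic (flat-window) mean shift update in 2-D.
% Data, radius and starting point are illustrative assumptions.
rng(0);                                   % fixed seed, for repeatability only
X = [0.5*randn(2,200), 0.5*randn(2,50) + repmat([3;3],1,50)];  % numDim x numPts
h = 1;                                    % radius of the ball S_h
x = [1; 1];                               % arbitrary starting point
for iter = 1:100
    d2 = sum((X - repmat(x,1,size(X,2))).^2, 1);  % squared distances to x
    inWin = d2 < h^2;                     % samples inside S_h
    if ~any(inWin), break; end            % empty window: nothing to average
    M = mean(X(:,inWin), 2) - x;          % mean shift vector M_h(x)
    x = x + M;                            % move x to the window mean
    if norm(M) < 1e-3*h, break; end       % stop once the shift is negligible
end
disp(x')                                  % final position of x
```

With these settings the point should settle close to the center of one of the two dense regions, which is the behavior the basic form is meant to show.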
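2). Improved form
The basic form implicitly assumes that every sample point in the effective region counts equally, which usually does not hold. Most commonly, a sample's influence shrinks as its distance from $x$ grows. This leads to the following improved form:
$$M_h(x) = \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}$$
where $G(x)$ is the kernel function, $h$ is the bandwidth (strictly speaking it should be a bandwidth matrix, taken to be diagonal, but the diagonal entries are usually set equal, so it can be written as a scalar), and $w(x_i)$ is the sample weight.
The basic form can then be expressed in the same notation by using a uniform kernel with unit weights, which gives a unified representation:
$$G(x) = \begin{cases} 1, & \|x\| < 1 \\ 0, & \text{otherwise} \end{cases}, \qquad w(x_i) = 1$$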
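As a rough check of the unified representation, the sketch below evaluates the kernel-weighted mean shift vector for a uniform kernel and for a Gaussian-shaped kernel on a few toy points. The sample values, the bandwidth, and the names `Gflat`, `Ggauss`, and `shift` are illustrative assumptions, not part of the original program.

```matlab
% Sketch: kernel-weighted mean shift vector with a pluggable kernel.
% Toy data and helper names are assumptions made for illustration.
X = [0 1 2 5; 0 0 1 5];                   % samples, numDim x numPts
x = [1; 0];                               % current point
h = 1.5;                                  % bandwidth
w = ones(1, size(X,2));                   % sample weights w(x_i), all equal here
u = sum((X - repmat(x,1,size(X,2))).^2, 1) / h^2;   % ||(x_i - x)/h||^2
Gflat  = @(u) double(u < 1);              % uniform kernel
Ggauss = @(u) exp(-u/2);                  % Gaussian-shaped kernel (constant omitted)
shift  = @(g) (X * (g .* w)') / sum(g .* w) - x;    % weighted mean minus x
M_flat  = shift(Gflat(u));                % improved form with the uniform kernel
M_gauss = shift(Ggauss(u));               % improved form with a decaying kernel
M_basic = mean(X(:, u < 1), 2) - x;       % basic form: plain window mean minus x
disp(norm(M_flat - M_basic))              % the two forms should agree
```

The Gaussian-shaped variant still uses every sample, but far samples contribute very little, which is the distance-dependent weighting the improved form is meant to express.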
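2. Explanation of mean shift
1). Mathematical derivation
Common methods for probability density estimation include histogram estimation, K-nearest-neighbor estimation, and kernel estimation. The kernel estimate is written as follows (kernel normalization constants are omitted, and $d$ is the dimension of the samples):
$$\hat{f}_K(x) = \frac{\sum_{i=1}^{n} K\left(\frac{x_i - x}{h}\right) w(x_i)}{h^d \sum_{i=1}^{n} w(x_i)}$$
where $K(x)$ again denotes a kernel function. Writing the kernel through its profile, $K(x) = k(\|x\|^2)$, the derivative of the density estimate $\hat{f}_K(x)$ is:
$$\nabla \hat{f}_K(x) = \frac{2 \sum_{i=1}^{n} (x - x_i)\, k'\left(\left\|\frac{x_i - x}{h}\right\|^2\right) w(x_i)}{h^{d+2} \sum_{i=1}^{n} w(x_i)}$$
Let $g(x) = -k'(x)$ and $G(x) = g(\|x\|^2)$, which is also a kernel function. Factoring further gives:
$$\nabla \hat{f}_K(x) = \frac{2}{h^2} \left[ \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}{h^d \sum_{i=1}^{n} w(x_i)} \right] \left[ \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)} \right]$$
The second factor is itself a kernel estimate of the probability density, denoted $\hat{f}_G(x)$, and the third factor is the improved form of mean shift given above, so the expression can be rewritten as:
$$\nabla \hat{f}_K(x) = \frac{2}{h^2}\, \hat{f}_G(x)\, M_h(x)$$
Two interpretations follow. The first is solving for a local maximum of the probability density. Set $\nabla \hat{f}_K(x) = 0$; since $\hat{f}_G(x) > 0$, we have:
$$M_h(x) = 0 \quad\Longleftrightarrow\quad x = \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\, x_i}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}$$
This shows that mean shift is in essence solving for a local maximum of the probability density: the mean shift vector keeps moving the target point toward a density maximum. When the data set is very large, however, traversing all sample points at every step is clearly impractical, so one usually selects a region around the target point $x$ and iterates greedily, gradually converging to the density maximum. Another, more reasonable, reading is that a uniform kernel is folded into the kernel $G(x)$ to represent the selected effective region, and the iteration then runs until convergence.
The second interpretation takes the viewpoint of gradient-ascent optimization:
$$x \leftarrow x + M_h(x) = x + \frac{h^2}{2 \hat{f}_G(x)}\, \nabla \hat{f}_K(x)$$
That is, the mean shift vector acts as gradient ascent on the probability density with an adaptive step size. The step is large where the density is small, and as the point approaches a local maximum the density becomes large and the step becomes small, which matches how the step size should vary in gradient-based optimization.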
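The gradient-ascent reading can be checked numerically. The sketch below assumes a Gaussian profile $k(u) = e^{-u/2}$ (so $g(u) = -k'(u) = \tfrac{1}{2} e^{-u/2}$), omits normalization constants, and compares the mean shift vector with $\frac{h^2}{2 \hat{f}_G(x)} \nabla \hat{f}_K(x)$, where the gradient is approximated by central differences. All data values and helper names are illustrative assumptions.

```matlab
% Numerical check: M_h(x) versus h^2 / (2*f_G(x)) * grad f_K(x), Gaussian profile.
% Samples, weights and bandwidth are made-up illustration values.
X = [0 1 2 4; 0 2 1 3];                   % samples, numDim x numPts
w = [1 2 1 1];                            % sample weights w(x_i)
h = 1.2;                                  % bandwidth
x = [1; 1];                               % evaluation point
d = size(X,1);                            % dimension
k = @(u) exp(-u/2);                       % profile of K (constant omitted)
g = @(u) 0.5*exp(-u/2);                   % g = -k'
u  = @(p) sum((X - repmat(p,1,size(X,2))).^2, 1) / h^2;  % ||(x_i - p)/h||^2
fK = @(p) sum(k(u(p)).*w) / (h^d*sum(w));                % density estimate with K
fG = @(p) sum(g(u(p)).*w) / (h^d*sum(w));                % density estimate with G
M  = @(p) (X*(g(u(p)).*w)') / sum(g(u(p)).*w) - p;       % mean shift vector
gradFK = zeros(d,1);                      % central-difference gradient of fK
step = 1e-6;
for j = 1:d
    e = zeros(d,1); e(j) = step;
    gradFK(j) = (fK(x+e) - fK(x-e)) / (2*step);
end
lhs = M(x);                               % mean shift vector at x
rhs = h^2 / (2*fG(x)) * gradFK;           % scaled density gradient at x
fprintf('difference: %g\n', norm(lhs - rhs));
```

The two vectors should agree up to finite-difference error, and the factor `h^2/(2*fG(x))` is the adaptive step size: it grows where the estimated density is small.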
This explains both what mean shift means and why it is reasonable, and it makes it easy to see why mean shift is so effective and so widely applicable.
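2). Generalization
Going one step further: although the general form of mean shift is derived from the kernel estimate of the probability density and is built around the kernel function, its normalized form means that, in theory, it can be generalized to:
$$M(x) = \frac{\sum_{i=1}^{n} w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} w(x_i)}$$
where $w(x_i)$ sets the overall weight of the offset vector $(x_i - x)$. It can be chosen freely, but it must carry some meaning. The mean shift vector clearly leans toward sample points with larger weights, so from the viewpoint of maximizing the probability density, $w(x_i)$ can be some representation of the probability density at $x_i$.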
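A tiny example, with made-up values, shows how the generalized form leans toward heavily weighted samples. The helper name `genShift` and the weights are assumptions for illustration only.

```matlab
% Sketch: generalized mean shift with freely chosen sample weights.
% Two samples on a line, with x starting at their midpoint.
X = [0 4; 0 0];                           % samples, numDim x numPts
x = [2; 0];                               % current point (midpoint of the samples)
genShift = @(w) (X * w') / sum(w) - x;    % weighted mean of the samples minus x
M_equal = genShift([1 1]);                % equal weights: no shift
M_right = genShift([1 3]);                % heavier right sample: shift to the right
disp([M_equal, M_right])
```

If the weights stand for density at each sample, the same mechanism pulls $x$ toward the denser side.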
3. mean shift MATLAB program
testMeanShift.m
clear
clc
profile on
bandwidth = 1;
%% Load data
data_load = dlmread('gauss_data.txt');
[~,dim] = size(data_load);
data = data_load(:,1:dim-1);   % drop the last column (class labels)
x = data';                     % numDim x numPts, as MeanShiftCluster expects
%% Clustering
tic
[clustCent,point2cluster,clustMembsCell] = MeanShiftCluster(x,bandwidth);
% clustCent: cluster centers, D*K; point2cluster: cluster label of each point, 1*N
toc
%% Plot
numClust = length(clustMembsCell);
figure(2),clf,hold on
cVec = 'bgrcmykbgrcmykbgrcmykbgrcmyk';%, cVec = [cVec cVec];
for k = 1:min(numClust,length(cVec))
    myMembers = clustMembsCell{k};
    myClustCen = clustCent(:,k);
    plot(x(1,myMembers),x(2,myMembers),[cVec(k) '.'])
    plot(myClustCen(1),myClustCen(2),'o','MarkerEdgeColor','k','MarkerFaceColor',cVec(k), 'MarkerSize',10)
end
title(['no shifting, numClust:' int2str(numClust)])
MeanShiftCluster.m
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag)
%perform MeanShift Clustering of data using a flat kernel
%
% ---INPUT---
% dataPts           - input data, (numDim x numPts)
% bandWidth         - is bandwidth parameter (scalar)
% plotFlag          - display output if 2 or 3 D    (logical)
% ---OUTPUT---
% clustCent         - is locations of cluster centers (numDim x numClust)
% data2cluster      - for every data point which cluster it belongs to (numPts)
% cluster2dataCell  - for every cluster which points are in it (numClust)
%
% Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"

%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end
if nargin < 3
    plotFlag = false;
end

%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;
maxPos          = max(dataPts,[],2);     %biggest size in each dimension
minPos          = min(dataPts,[],2);     %smallest size in each dimension
boundBox        = maxPos-minPos;         %bounding box size
sizeSpace       = norm(boundBox);        %indicator of size of data space
stopThresh      = 1e-3*bandWidth;        %when mean has converged
clustCent       = [];                    %center of clust
beenVisitedFlag = zeros(1,numPts);       %track if a point has been seen already
numInitPts      = numPts;                %number of points to possibly use as initialization points
clusterVotes    = zeros(1,numPts);       %used to resolve conflicts on cluster membership

while numInitPts
    tempInd          = ceil( (numInitPts-1e-6)*rand);  %pick a random seed point
    stInd            = initPtInds(tempInd);            %use this point as start of mean
    myMean           = dataPts(:,stInd);               %initialize mean to this point's location
    myMembers        = [];                             %points that will get added to this cluster
    thisClusterVotes = zeros(1,numPts);                %used to resolve conflicts on cluster membership

    while 1     %loop until convergence
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2,1);  %dist squared from mean to all points still active (sum over dimensions, also correct for 1-D data)
        inInds      = find(sqDistToAll < bandSq);                     %points within bandWidth
        thisClusterVotes(inInds) = thisClusterVotes(inInds)+1;        %add a vote for all the in points belonging to this cluster

        myOldMean   = myMean;                        %save the old mean
        myMean      = mean(dataPts(:,inInds),2);     %compute the new mean
        myMembers   = [myMembers inInds];            %add any point within bandWidth to the cluster
        beenVisitedFlag(myMembers) = 1;              %mark that these points have been visited

        %*** plot stuff ****
        if plotFlag
            figure(1),clf,hold on
            if numDim == 2
                plot(dataPts(1,:),dataPts(2,:),'.')
                plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
                plot(myMean(1),myMean(2),'go')
                plot(myOldMean(1),myOldMean(2),'rd')
                pause
            end
        end

        %**** if mean doesn't move much stop this cluster ***
        if norm(myMean-myOldMean) < stopThresh

            %check for merge possibilities
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));   %distance from possible new clust max to old clust max
                if distToOther < bandWidth/2                  %if it is within bandwidth/2 merge new and old
                    mergeWith = cN;
                    break;
                end
            end

            if mergeWith > 0    % something to merge
                clustCent(:,mergeWith)    = 0.5*(myMean+clustCent(:,mergeWith));           %record the max as the mean of the two merged (I know biased towards new ones)
                %clustMembsCell{mergeWith} = unique([clustMembsCell{mergeWith} myMembers]); %record which points inside
                clusterVotes(mergeWith,:) = clusterVotes(mergeWith,:) + thisClusterVotes;  %add these votes to the merged cluster
            else    %it is a new cluster
                numClust                  = numClust+1;       %increment clusters
                clustCent(:,numClust)     = myMean;           %record the mean
                %clustMembsCell{numClust} = myMembers;        %store my members
                clusterVotes(numClust,:)  = thisClusterVotes;
            end

            break;
        end

    end

    initPtInds = find(beenVisitedFlag == 0);    %we can initialize with any of the points not yet visited
    numInitPts = length(initPtInds);            %number of active points in set

end

[~,data2cluster] = max(clusterVotes,[],1);      %a point belongs to the cluster with the most votes

%*** If they want the cluster2data cell find it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        myMembers = find(data2cluster == cN);
        cluster2dataCell{cN} = myMembers;
    end
end
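For the data, see the post "MATLAB中“fitgmdist”的用法及其GMM聚類算法" (usage of "fitgmdist" in MATLAB and its GMM clustering algorithm); save it as the file gauss_data.txt. The last column of the data is the class label.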
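4. Results
Note: the clustering result depends heavily on the kernel's bandwidth parameter, bandwidth, which has to be chosen to suit the data at hand.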