Meanshift均值漂移算法


 
 
通俗理解Meanshift均值漂移算法 
Meanshift車手?? 漂移?? 秋名山???   不,不,他是一組算法,  今天我就帶大家來了解一下機器學習中的Meanshift均值漂移.
Meanshift算法他的本質是一個迭代的過程 , 我先給大家講一下他的底層原理
 
 
1)概述
Mean-shift(均值遷移)的基本思想:在數據集中選定一個點,然后以這個點為圓心,r為半徑,畫一個圓(二維下是圓),求出這個點到所有點的向量的平均值,而圓心與向量均值的和為新的圓心,然后迭代此過程,直到滿足一點的條件結束。
后來在此基礎上加入了 核函數 和 權重系數 ,使得Mean-shift 算法開始流行起來。目前它在聚類、圖像平滑、分割、跟蹤等方面有着廣泛的應用。
 
2) 圖解過程
為了方便大家理解,借用下幾張圖來說明Mean-shift的基本過程。
第一張圖有一個子中心點,她向四周最近的點開始尋找,找到圓心與向量均值的和為新的圓心,然后依次循環,直到滿足條件,則不會再尋找其他圓心點

 

3)Mean-shift 算法函數
a)核心函數:sklearn.cluster.MeanShift(核函數:RBF核函數)
由上圖可知,圓心(或種子)的確定和半徑(或帶寬)的選擇,是影響算法效率的兩個主要因素。所以在sklearn.cluster.MeanShift中重點說明了這兩個參數的設定問題。
b)主要參數
bandwidth :半徑(或帶寬),float型。如果沒有給出,則使用sklearn.cluster.estimate_bandwidth計算出半徑(帶寬).(可選)
seeds :圓心(或種子),數組類型,即初始化的圓心。(可選)
bin_seeding :布爾值。如果為真,初始內核位置不是所有點的位置,而是點的離散版本的位置,其中點被分類到其粗糙度對應於帶寬的網格上。將此選項設置為True將加速算法,因為較少的種子將被初始化。默認值:False.如果種子參數(seeds)不為None則忽略。
c)主要屬性
cluster_centers_ : 數組類型。計算出的聚類中心的坐標。

 

labels_ :數組類型。每個數據點的分類標簽。
 
4)代碼詳解  這里用到的是一組貝葉斯數據
 
#分割數據集,拆分數據

#坐標軸負一問題
plt.rcParams['axes.unicode_minus'] =False
#分割數據集
from sklearn.model_selection import train_test_split
data=pd.read_csv('./貝葉斯.csv',header=None)
print(data.shape) #顯示幾行幾列

#拆分數據
dataset_X,dataset_y =data.iloc[:,:-1],data.iloc[:,-1]
# print(dataset_X.head())

## 將pandas轉為np.ndarray 可以用dataset = df.as_matrix()
dataset_X =dataset_X.values
dataset_y =dataset_y.values

#估算帶寬
from sklearn.cluster import estimate_bandwidth,MeanShift
# estimate_bandwidth有估計帶寬的意思 n_clusters聚類的個數 quantile分位數,分位點
bandwidth = estimate_bandwidth(dataset_X,quantile=0.1,n_samples=len(dataset_X))
#打印出帶寬
print(bandwidth).

#初始化聚類模型 bandwidth:帶寬 bin_seeding網格化數據點(加速模型)
meanshift = MeanShift(bandwidth=bandwidth,bin_seeding=True)
# 訓練模型
meanshift.fit(dataset_X)
print(meanshift.cluster_centers_)
print(meanshift.labels_)

此時打印除掉數據如下,

 

 

 

#最后一步,將圖形繪制出,查看一下效果

def visual_meanshift_effect(meanshift,dataset):
assert dataset.shape[1]==2,'only support dataset with 2 features'
X=dataset[:,0]
Y=dataset[:,1]
X_min,X_max=np.min(X)-1,np.max(X)+1
Y_min,Y_max=np.min(Y)-1,np.max(Y)+1
X_values,Y_values=np.meshgrid(np.arange(X_min,X_max,0.01),
np.arange(Y_min,Y_max,0.01))
# 預測網格點的標記
predict_labels=meanshift.predict(np.c_[X_values.ravel(),Y_values.ravel()])
predict_labels=predict_labels.reshape(X_values.shape)
plt.figure()
plt.imshow(predict_labels,interpolation='nearest',
extent=(X_values.min(),X_values.max(),
Y_values.min(),Y_values.max()),
cmap=plt.cm.Paired,
aspect='auto',
origin='lower')

# 將數據集繪制到圖表中
plt.scatter(X,Y,marker='v',facecolors='none',edgecolors='k',s=30)

# 將中心點繪制到圖中
centroids=meanshift.cluster_centers_
plt.scatter(centroids[:,0],centroids[:,1],marker='o',
s=100,linewidths=2,color='k',zorder=5,facecolors='b')
plt.title('MeanShift effect graph')
plt.xlim(X_min,X_max)
plt.ylim(Y_min,Y_max)
plt.xlabel('feature_0')
plt.ylabel('feature_1')
plt.show()
visual_meanshift_effect(meanshift,dataset_X)

 

 

 

 

 

 

 
 
 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM