AP聚類算法


一、算法簡介

Affinity Propagation聚類算法簡稱AP,是一個在07年發表在Science上的聚類算法。它實際屬於message-passing algorithms的一種。算法的基本思想將數據看成網絡中的節點,通過在數據點之間傳遞消息,分別是吸引度(responsibility)和歸屬度(availability),不斷修改聚類中心的數量與位置,直到整個數據集相似度達到最大,同時產生高聚類中心,並將其余各點分配到相應的聚類中。

二、算法描述

1、相關概念

  • Exemplar:指的是聚類中心,該聚類中心實際存在,並不是如同K-Means算法由計算生成的。 

  • Similarity:數據點i和點j的相似度記為s(i, j),是指點j作為點i的聚類中心的相似度。一般使用歐氏距離來計算;相似度值越大說明點與點的距離越近,便於后面的比較計算。 

  • Preference:數據點i的參考度稱為p(i)或s(i,i),是指點i作為聚類中心的參考度。一般取s相似度值的中值。 

  • Responsibility:r(i,k)用來描述點k適合作為數據點i的聚類中心的程度。 

  • Availability:a(i,k)用來描述點i選擇點k作為其聚類中心的適合程度。 

  • Damping factor(阻尼因子)λ:主要是起收斂作用的。

2、算法步驟

2.1 具體算法步驟

AP算法可能需要指定一些參數,如PreferenceDamping factor與最大迭代次數maxIterNum.

step 1: 初始化參數Damping factormaxIterNum,並讀取數據;

step 2:計算相似度矩陣Similarity[i,j],一般使用歐氏距離,並求出相似度矩陣的中位值並賦給Preference;

step 3: 更新吸引度矩陣;

step 4: 更新歸屬度矩陣;

setp 4: 判斷是否達到最大迭代次數或達到終止條件,如未達到跳轉step 2,否則繼續下一步;

setp 5: 生成最終Exemplar,並將各數據分配到相應的聚類中。

2.2 算法詳解

AP算法有兩個關鍵步驟,即更新吸引度矩陣與更新歸屬度矩陣。

更新吸引度矩陣:


更新歸屬度矩陣:


為了避免振盪,AP算法更新信息時引入了衰減系數λ。每條信息被設置為它前次迭代更新值的λ倍加上本次信息更新值的1-λ倍。其中,衰減系數

λλ是介於01之間的實數。即第t+1次r(i,k)與a(i,k)的迭代值

 

 

2.3 算法優缺點

優點:

  • 不需要事先指定聚類的數量

  • 聚類結果很穩定

  • 適用於非對稱相似性矩陣

  • 初始值不敏感

缺點:

  • 算法復雜度較高,為O(N*N*logN),該算法比較慢,對於大量數據,計算很久

三、算法實現(Java)

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
package cang.algorithms.clustering.ap;
 
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
 
/**
  * 近鄰傳播算法,半監督聚類算法<br>
  * 優點:不需事先指定類的個數;對初值的選取不敏感;對距離矩陣的對稱性沒要求<br>
  * 缺點:算法復雜度較高,為O(N*N*logN)
  *
  * @author cang
  *
  */
public class AP {
 
     private int maxIterNum;
     // 聚類結果不變次數
     private int changedCount;
     private int unchangeNum;
     private int dataNum;
     private Point[] dataset;
     // 相似度矩陣,數據點i和點j的相似度記為s(i, j),是指點j作為點i的聚類中心的相似度
     private double similar[][];
     // 吸引信息矩陣,r(i,k)用來描述點k適合作為數據點i的聚類中心的程度
     private double r[][];
     // 歸屬信息矩陣,a(i,k)用來描述點i選擇點k作為其聚類中心的適合程度
     private double a[][];
     // 衰減系數,主要是起收斂作用的
     private double lambda;
     // 聚類中心
     private List<Integer> exemplar;
     private List<Integer> oldExemplar;
 
     public AP() {
         this ( 1000 , 0.9 );
     }
 
     public AP( int maxIterNum, double lambda) {
         this .maxIterNum = maxIterNum;
         this .lambda = lambda;
     }
 
     /**
      * 數據初始化
      */
     public void init() {
         oldExemplar = new ArrayList<Integer>();
         exemplar = new ArrayList<Integer>();
         similar = new double [dataNum][dataNum];
         r = new double [dataNum][dataNum];
         a = new double [dataNum][dataNum];
         for ( int i = 0 ; i < dataset.length; i++) {
             for ( int j = i + 1 ; j < dataset.length; j++) {
                 similar[i][j] = distance(dataset[i].dimensioin,
                         dataset[j].dimensioin);
                 similar[j][i] = similar[i][j];
             }
         }
         setPreference( 3 );
     }
 
     /**
      * 獲取數據點i的參考度<br>
      * 稱為p(i)或s(i,i) 是指點i作為聚類中心的參考度。一般取s相似度值的中值
      *
      * @param prefType 參考度類型
      */
     private void setPreference( int prefType) {
         List<Double> list = new ArrayList<Double>();
         // find the median
         for ( int i = 0 ; i < dataNum; i++) {
             for ( int j = i + 1 ; j < dataNum; j++) {
                 list.add(similar[i][j]);
             }
         }
         Collections.sort(list);
         double pref = 0 ;
         // use the median as preference
         if (prefType == 1 ) {
             if (list.size() % 2 == 0 ) {
                 pref = (list.get(list.size() / 2 )
                         + list.get(list.size() / 2 - 1 )) / 2 ;
             } else {
                 pref = list.get((list.size()) / 2 );
             }
             // use the minimum as preference
         } else if (prefType == 2 ) {
             pref = list.get( 0 );
             // use the 0.5 * (min + max) as preference
         } else if (prefType == 3 ) {
             pref = list.get( 0 )
                     + (list.get(list.size() - 1 ) + list.get( 0 )) * 0.5 ;
             // use the maximum as preference
         } else if (prefType == 4 ) {
             pref = list.get(list.size() - 1 );
         } else {
             System.out.println( "prefType error" );
             System.exit(- 1 );
         }
         System.out.println(pref);
         for ( int i = 0 ; i < dataNum; i++) {
             similar[i][i] = pref;
         }
     }
 
     public void clustering() {
         for ( int i = 0 ; i < maxIterNum; i++) {
             updateResponsible();
             updateAvailable();
 
             oldExemplar.clear();
             if (!exemplar.isEmpty()) {
                 for (Integer v : exemplar) {
                     oldExemplar.add(v);
                 }
             }
             exemplar.clear();
 
             changedCount = 0 ;
             // 獲取聚類中心
             for ( int k = 0 ; k < dataNum; k++) {
                 if (r[k][k] + a[k][k] > 0 ) {
                     exemplar.add(k);
                 }
             }
             // data point assignment
             assignCluster();
 
             if (changedCount == 0 ) {
                 unchangeNum++;
                 if (unchangeNum > 10 ) {
                     maxIterNum = i;
                     break ;
                 }
             } else {
                 unchangeNum = 0 ;
             }
 
         }
         // 生成預測標簽
         setPredictLabel();
     }
 
     /**
      * 將各數據點分配到聚類中心
      */
     private void assignCluster() {
         for ( int i = 0 ; i < dataNum; i++) {
             double max = Double.MIN_VALUE;
             int index = 0 ;
             for (Integer k : exemplar) {
                 if (max < similar[i][k]) {
                     max = similar[i][k];
                     index = k;
                 }
             }
             if (dataset[i].cid != index) {
                 dataset[i].cid = index;
                 changedCount++;
             }
         }
     }
 
     /**
      * 更新吸引信息矩陣
      */
     private void updateResponsible() {
         for ( int i = 0 ; i < dataNum; i++) {
             for ( int k = 0 ; k < dataNum; k++) {
                 double max = Double.MIN_VALUE;
                 for ( int j = 0 ; j < dataNum; j++) {
                     if (j != k) {
                         if (max < a[i][j] + similar[i][j]) {
                             max = a[i][j] + similar[i][j];
                         }
                     }
                 }
                 r[i][k] = ( 1 - lambda) * (similar[i][k] - max)
                         + lambda * r[i][k];
             }
         }
     }
 
     /**
      * 更新歸屬信息矩陣
      */
     private void updateAvailable() {
         for ( int i = 0 ; i < dataNum; i++) {
             for ( int k = 0 ; k < dataNum; k++) {
                 if (i == k) {
                     double sum = 0 ;
                     for ( int j = 0 ; j < dataNum; j++) {
                         if (j != k) {
                             if (r[j][k] > 0 ) {
                                 sum += r[j][k];
                             }
                         }
                     }
                     a[k][k] = sum;
                 } else {
                     double sum = 0 ;
                     for ( int j = 0 ; j < dataNum; j++) {
                         if (j != i && j != k) {
                             if (r[j][k] > 0 ) {
                                 sum += r[j][k];
                             }
                         }
                     }
                     a[i][k] = ( 1 - lambda) * (r[k][k] + sum) + lambda * a[i][k];
                     if (a[i][k] > 0 ) {
                         a[i][k] = 0 ;
                     }
                 }
             }
         }
     }
 
     /**
      * 生成數據點的聚類標簽
      */
     private void setPredictLabel() {
         Map<Integer, String> labelMap = new HashMap<Integer, String>();
         for ( int cid : exemplar) {
             Map<String, Integer> tempMap = new HashMap<String, Integer>();
             for (Point p : dataset) {
                 if (cid == p.cid) {
                     if (tempMap.get(p.label) == null ) {
                         tempMap.put(p.label, 1 );
                     } else {
                         tempMap.put(p.label, tempMap.get(p.label) + 1 );
                     }
                 }
             }
             String maxLabel = null ;
             int temp = 0 ;
             for (Entry<String, Integer> iter : tempMap.entrySet()) {
                 if (temp < iter.getValue()) {
                     temp = iter.getValue();
                     maxLabel = iter.getKey();
                 }
             }
             labelMap.put(cid, maxLabel);
         }
 
         for (Point p : dataset) {
             p.predictLabel = labelMap.get(p.cid);
         }
     }
 
     /**
      * 計算數據點之間的距離
      *
      * @param a 數據的坐標
      * @param b 另一個數據的坐標
      * @return
      */
     private double distance( double [] a, double [] b) {
         if (a.length != b.length) {
             throw new IllegalArgumentException( "Arrry a not equal array b!" );
         }
         double sum = 0 ;
         for ( int i = 0 ; i < a.length; i++) {
             double dp = a[i] - b[i];
             sum += dp * dp;
         }
         return ( double ) Math.sqrt(sum);
     }
 
 
     /**
      * 讀取數據集<br>
      * 將數據集保存到數據集中
      *
      * @param fileName 文件名
      * @param split 分隔符
      * @param labelAtHead 標簽是否在頭部
      * @throws IOException
      */
     public void importDataWithLabel(String fileName, String split,
             boolean labelAtHead) throws IOException {
         int dimensionNum = 0 ;
         List<Point> dataList = new ArrayList<Point>();
         // 讀取數據文件
         BufferedReader reader = new BufferedReader( new FileReader(fileName));
         String line = null ;
         while ((line = reader.readLine()) != null ) {
             if (line.trim().equals( "" )) {
                 continue ;
             }
             // 字符串以split拆分
             String[] splitStrs = line.split(split);
             dimensionNum = splitStrs.length - 1 ;
             double [] temp = new double [dimensionNum];
 
             String label = splitStrs[dimensionNum];
             if (labelAtHead) {
                 label = splitStrs[ 0 ];
                 for ( int i = 0 ; i < dimensionNum; i++) {
                     temp[i] = Double.parseDouble(splitStrs[i + 1 ]);
                 }
             } else {
                 for ( int i = 0 ; i < dimensionNum; i++) {
                     temp[i] = Double.parseDouble(splitStrs[i]);
                 }
             }
             dataList.add( new Point(temp, label));
             dataNum++;
         }
         reader.close();
         Collections.shuffle(dataList);
         dataset = new Point[dataList.size()];
         dataList.toArray(dataset);
     }
 
     /**
      * 打印輸出聚類信息
      */
     public void printInfo() {
         System.out.println( "迭代次數:" + maxIterNum);
         System.out.println( "聚類數目為:" + exemplar.size());
         for ( int j = 0 ; j < exemplar.size(); j++) {
             System.out.println(j + ": " + exemplar.get(j));
         }
         for (Point point : dataset) {
             System.out.println(point);
         }
     }
 
 
     static class Point {
         // 數據標簽
         private String label;
         // 聚類預測的標簽
         private String predictLabel;
         // 數據點所屬簇id
         private int cid;
         // 數據點的維度
         private double dimensioin[];
 
         public Point( double dimensioin[], String label) {
             this .label = label;
             init(dimensioin);
         }
 
         public Point( double dimensioin[]) {
             init(dimensioin);
         }
 
         public void init( double value[]) {
             dimensioin = new double [value.length];
             for ( int i = 0 ; i < value.length; i++) {
                 dimensioin[i] = value[i];
             }
         }
 
         @Override
         public String toString() {
             return "Point [label=" + label + ", predictLabel=" + predictLabel
                     + ", cid=" + cid + ", dimensioin="
                     + Arrays.toString(dimensioin) + "]" ;
         }
 
     }
 
public static void main(String[] args) throws IOException {
         AP ap = new AP( 10000 , 0.6 );
         ap.importDataWithLabel(FILEPATH, "," , false );
         ap.init();
         ap.clustering();
         ap.printInfo();
     }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM