一、算法簡介
Affinity Propagation聚類算法簡稱AP,是一個在07年發表在Science上的聚類算法。它實際屬於message-passing algorithms的一種。算法的基本思想將數據看成網絡中的節點,通過在數據點之間傳遞消息,分別是吸引度(responsibility)和歸屬度(availability),不斷修改聚類中心的數量與位置,直到整個數據集相似度達到最大,同時產生高聚類中心,並將其余各點分配到相應的聚類中。
二、算法描述
1、相關概念
-
Exemplar:指的是聚類中心,該聚類中心實際存在,並不是如同K-Means算法由計算生成的。
-
Similarity:數據點i和點j的相似度記為s(i, j),是指點j作為點i的聚類中心的相似度。一般使用歐氏距離來計算;相似度值越大說明點與點的距離越近,便於后面的比較計算。
-
Preference:數據點i的參考度稱為p(i)或s(i,i),是指點i作為聚類中心的參考度。一般取s相似度值的中值。
-
Responsibility:r(i,k)用來描述點k適合作為數據點i的聚類中心的程度。
-
Availability:a(i,k)用來描述點i選擇點k作為其聚類中心的適合程度。
-
Damping factor(阻尼因子)λ:主要是起收斂作用的。
2、算法步驟
2.1 具體算法步驟
AP算法可能需要指定一些參數,如Preference與Damping factor與最大迭代次數maxIterNum.
step 1: 初始化參數Damping factor與maxIterNum,並讀取數據;
step 2:計算相似度矩陣Similarity[i,j],一般使用歐氏距離,並求出相似度矩陣的中位值並賦給Preference;
step 3: 更新吸引度矩陣;
step 4: 更新歸屬度矩陣;
setp 4: 判斷是否達到最大迭代次數或達到終止條件,如未達到跳轉step 2,否則繼續下一步;
setp 5: 生成最終Exemplar,並將各數據分配到相應的聚類中。
2.2 算法詳解
AP算法有兩個關鍵步驟,即更新吸引度矩陣與更新歸屬度矩陣。
更新吸引度矩陣:
更新歸屬度矩陣:
為了避免振盪,AP算法更新信息時引入了衰減系數λ。每條信息被設置為它前次迭代更新值的倍加上本次信息更新值的1-倍。其中,衰減系數
λ是介於0到1之間的實數。即第t+1次r(i,k)與a(i,k)的迭代值:
2.3 算法優缺點
優點:
-
不需要事先指定聚類的數量
-
聚類結果很穩定
-
適用於非對稱相似性矩陣
-
初始值不敏感
缺點:
-
算法復雜度較高,為O(N*N*logN),該算法比較慢,對於大量數據,計算很久
三、算法實現(Java)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
|
package
cang.algorithms.clustering.ap;
import
java.io.BufferedReader;
import
java.io.FileReader;
import
java.io.IOException;
import
java.util.ArrayList;
import
java.util.Arrays;
import
java.util.Collections;
import
java.util.HashMap;
import
java.util.List;
import
java.util.Map;
import
java.util.Map.Entry;
/**
* 近鄰傳播算法,半監督聚類算法<br>
* 優點:不需事先指定類的個數;對初值的選取不敏感;對距離矩陣的對稱性沒要求<br>
* 缺點:算法復雜度較高,為O(N*N*logN)
*
* @author cang
*
*/
public
class
AP {
private
int
maxIterNum;
// 聚類結果不變次數
private
int
changedCount;
private
int
unchangeNum;
private
int
dataNum;
private
Point[] dataset;
// 相似度矩陣,數據點i和點j的相似度記為s(i, j),是指點j作為點i的聚類中心的相似度
private
double
similar[][];
// 吸引信息矩陣,r(i,k)用來描述點k適合作為數據點i的聚類中心的程度
private
double
r[][];
// 歸屬信息矩陣,a(i,k)用來描述點i選擇點k作為其聚類中心的適合程度
private
double
a[][];
// 衰減系數,主要是起收斂作用的
private
double
lambda;
// 聚類中心
private
List<Integer> exemplar;
private
List<Integer> oldExemplar;
public
AP() {
this
(
1000
,
0.9
);
}
public
AP(
int
maxIterNum,
double
lambda) {
this
.maxIterNum = maxIterNum;
this
.lambda = lambda;
}
/**
* 數據初始化
*/
public
void
init() {
oldExemplar =
new
ArrayList<Integer>();
exemplar =
new
ArrayList<Integer>();
similar =
new
double
[dataNum][dataNum];
r =
new
double
[dataNum][dataNum];
a =
new
double
[dataNum][dataNum];
for
(
int
i =
0
; i < dataset.length; i++) {
for
(
int
j = i +
1
; j < dataset.length; j++) {
similar[i][j] = distance(dataset[i].dimensioin,
dataset[j].dimensioin);
similar[j][i] = similar[i][j];
}
}
setPreference(
3
);
}
/**
* 獲取數據點i的參考度<br>
* 稱為p(i)或s(i,i) 是指點i作為聚類中心的參考度。一般取s相似度值的中值
*
* @param prefType 參考度類型
*/
private
void
setPreference(
int
prefType) {
List<Double> list =
new
ArrayList<Double>();
// find the median
for
(
int
i =
0
; i < dataNum; i++) {
for
(
int
j = i +
1
; j < dataNum; j++) {
list.add(similar[i][j]);
}
}
Collections.sort(list);
double
pref =
0
;
// use the median as preference
if
(prefType ==
1
) {
if
(list.size() %
2
==
0
) {
pref = (list.get(list.size() /
2
)
+ list.get(list.size() /
2
-
1
)) /
2
;
}
else
{
pref = list.get((list.size()) /
2
);
}
// use the minimum as preference
}
else
if
(prefType ==
2
) {
pref = list.get(
0
);
// use the 0.5 * (min + max) as preference
}
else
if
(prefType ==
3
) {
pref = list.get(
0
)
+ (list.get(list.size() -
1
) + list.get(
0
)) *
0.5
;
// use the maximum as preference
}
else
if
(prefType ==
4
) {
pref = list.get(list.size() -
1
);
}
else
{
System.out.println(
"prefType error"
);
System.exit(-
1
);
}
System.out.println(pref);
for
(
int
i =
0
; i < dataNum; i++) {
similar[i][i] = pref;
}
}
public
void
clustering() {
for
(
int
i =
0
; i < maxIterNum; i++) {
updateResponsible();
updateAvailable();
oldExemplar.clear();
if
(!exemplar.isEmpty()) {
for
(Integer v : exemplar) {
oldExemplar.add(v);
}
}
exemplar.clear();
changedCount =
0
;
// 獲取聚類中心
for
(
int
k =
0
; k < dataNum; k++) {
if
(r[k][k] + a[k][k] >
0
) {
exemplar.add(k);
}
}
// data point assignment
assignCluster();
if
(changedCount ==
0
) {
unchangeNum++;
if
(unchangeNum >
10
) {
maxIterNum = i;
break
;
}
}
else
{
unchangeNum =
0
;
}
}
// 生成預測標簽
setPredictLabel();
}
/**
* 將各數據點分配到聚類中心
*/
private
void
assignCluster() {
for
(
int
i =
0
; i < dataNum; i++) {
double
max = Double.MIN_VALUE;
int
index =
0
;
for
(Integer k : exemplar) {
if
(max < similar[i][k]) {
max = similar[i][k];
index = k;
}
}
if
(dataset[i].cid != index) {
dataset[i].cid = index;
changedCount++;
}
}
}
/**
* 更新吸引信息矩陣
*/
private
void
updateResponsible() {
for
(
int
i =
0
; i < dataNum; i++) {
for
(
int
k =
0
; k < dataNum; k++) {
double
max = Double.MIN_VALUE;
for
(
int
j =
0
; j < dataNum; j++) {
if
(j != k) {
if
(max < a[i][j] + similar[i][j]) {
max = a[i][j] + similar[i][j];
}
}
}
r[i][k] = (
1
- lambda) * (similar[i][k] - max)
+ lambda * r[i][k];
}
}
}
/**
* 更新歸屬信息矩陣
*/
private
void
updateAvailable() {
for
(
int
i =
0
; i < dataNum; i++) {
for
(
int
k =
0
; k < dataNum; k++) {
if
(i == k) {
double
sum =
0
;
for
(
int
j =
0
; j < dataNum; j++) {
if
(j != k) {
if
(r[j][k] >
0
) {
sum += r[j][k];
}
}
}
a[k][k] = sum;
}
else
{
double
sum =
0
;
for
(
int
j =
0
; j < dataNum; j++) {
if
(j != i && j != k) {
if
(r[j][k] >
0
) {
sum += r[j][k];
}
}
}
a[i][k] = (
1
- lambda) * (r[k][k] + sum) + lambda * a[i][k];
if
(a[i][k] >
0
) {
a[i][k] =
0
;
}
}
}
}
}
/**
* 生成數據點的聚類標簽
*/
private
void
setPredictLabel() {
Map<Integer, String> labelMap =
new
HashMap<Integer, String>();
for
(
int
cid : exemplar) {
Map<String, Integer> tempMap =
new
HashMap<String, Integer>();
for
(Point p : dataset) {
if
(cid == p.cid) {
if
(tempMap.get(p.label) ==
null
) {
tempMap.put(p.label,
1
);
}
else
{
tempMap.put(p.label, tempMap.get(p.label) +
1
);
}
}
}
String maxLabel =
null
;
int
temp =
0
;
for
(Entry<String, Integer> iter : tempMap.entrySet()) {
if
(temp < iter.getValue()) {
temp = iter.getValue();
maxLabel = iter.getKey();
}
}
labelMap.put(cid, maxLabel);
}
for
(Point p : dataset) {
p.predictLabel = labelMap.get(p.cid);
}
}
/**
* 計算數據點之間的距離
*
* @param a 數據的坐標
* @param b 另一個數據的坐標
* @return
*/
private
double
distance(
double
[] a,
double
[] b) {
if
(a.length != b.length) {
throw
new
IllegalArgumentException(
"Arrry a not equal array b!"
);
}
double
sum =
0
;
for
(
int
i =
0
; i < a.length; i++) {
double
dp = a[i] - b[i];
sum += dp * dp;
}
return
(
double
) Math.sqrt(sum);
}
/**
* 讀取數據集<br>
* 將數據集保存到數據集中
*
* @param fileName 文件名
* @param split 分隔符
* @param labelAtHead 標簽是否在頭部
* @throws IOException
*/
public
void
importDataWithLabel(String fileName, String split,
boolean
labelAtHead)
throws
IOException {
int
dimensionNum =
0
;
List<Point> dataList =
new
ArrayList<Point>();
// 讀取數據文件
BufferedReader reader =
new
BufferedReader(
new
FileReader(fileName));
String line =
null
;
while
((line = reader.readLine()) !=
null
) {
if
(line.trim().equals(
""
)) {
continue
;
}
// 字符串以split拆分
String[] splitStrs = line.split(split);
dimensionNum = splitStrs.length -
1
;
double
[] temp =
new
double
[dimensionNum];
String label = splitStrs[dimensionNum];
if
(labelAtHead) {
label = splitStrs[
0
];
for
(
int
i =
0
; i < dimensionNum; i++) {
temp[i] = Double.parseDouble(splitStrs[i +
1
]);
}
}
else
{
for
(
int
i =
0
; i < dimensionNum; i++) {
temp[i] = Double.parseDouble(splitStrs[i]);
}
}
dataList.add(
new
Point(temp, label));
dataNum++;
}
reader.close();
Collections.shuffle(dataList);
dataset =
new
Point[dataList.size()];
dataList.toArray(dataset);
}
/**
* 打印輸出聚類信息
*/
public
void
printInfo() {
System.out.println(
"迭代次數:"
+ maxIterNum);
System.out.println(
"聚類數目為:"
+ exemplar.size());
for
(
int
j =
0
; j < exemplar.size(); j++) {
System.out.println(j +
": "
+ exemplar.get(j));
}
for
(Point point : dataset) {
System.out.println(point);
}
}
static
class
Point {
// 數據標簽
private
String label;
// 聚類預測的標簽
private
String predictLabel;
// 數據點所屬簇id
private
int
cid;
// 數據點的維度
private
double
dimensioin[];
public
Point(
double
dimensioin[], String label) {
this
.label = label;
init(dimensioin);
}
public
Point(
double
dimensioin[]) {
init(dimensioin);
}
public
void
init(
double
value[]) {
dimensioin =
new
double
[value.length];
for
(
int
i =
0
; i < value.length; i++) {
dimensioin[i] = value[i];
}
}
@Override
public
String toString() {
return
"Point [label="
+ label +
", predictLabel="
+ predictLabel
+
", cid="
+ cid +
", dimensioin="
+ Arrays.toString(dimensioin) +
"]"
;
}
}
public
static
void
main(String[] args)
throws
IOException {
AP ap =
new
AP(
10000
,
0.6
);
ap.importDataWithLabel(FILEPATH,
","
,
false
);
ap.init();
ap.clustering();
ap.printInfo();
}
}
|