在使用多目標跟蹤算法時,接觸到了匈牙利匹配算法,一直沒時間好好總結下,現在來填坑。。
1. 基礎概念
1.1 二分圖
我們之前了解過圖(Graph)的概念,圖一般可以用G(V, E)來表示,V表示圖中的頂點,E表示圖中的邊。如下面,這個圖中有四個頂點,五條邊。
二分圖(Bipartite graph)是一類特殊的圖,它可以被划分為兩個部分,每個部分內的點互不相連,如下面是一個典型的二分圖,圖中的點可分為X,Y兩部分,X內部的點互補相連,Y內部的點也互不相連。我們也可以發現二分圖中一定不存在環。(二分圖又稱為二部圖,偶圖)
1.2 二分圖匹配
二分圖的匹配可以看成是二分圖的一個子圖,該子圖滿足以下條件:子圖中不存在有任意的兩條邊依附於同一個頂點
如下面左圖是一個二分圖,右圖就是它的一個匹配,右圖中每條邊都沒有公共端點,可以看出其是二分圖的一個子集。概念上有點繞,我們通俗點理解:有一個班級的學生要結成男女兩兩一組,但每個學生只想自己喜歡的異性結成一組,於是這就會有沖突,而匹配就是要找出這樣的男女組成,保證一個男生只和一個女生組合。
二分圖的匹配問題在有限資源分配時經常會用到,主要是為了保證某一個資源分且只分到某一個用戶的手中
1.3 二分圖最大匹配
二分圖最大匹配,就是在二分圖的所有匹配中,找出邊數最大的匹配。還是以上面的情景來理解:有一個班級的學生要結成男女兩兩一組,但每個學生只想自己喜歡的異性結成一組,匹配是保證一個男生只和一個女生組合,而最大匹配則是盡量保證沒有人落單,即二分圖最大匹配就是要給出一個最優方案,使得結成的組數最多
匈牙利算法就是尋找二分圖最大匹配方案的經典算法
1.4 二分圖最大權完美匹配
首先說二分圖完美匹配,如果一個二分圖的所有點都是匹配點(匹配邊中某一條邊的端點),則稱這個匹配是完美匹配。回到上面的情景,完美匹配就是可以得到一個方案,使得所有男女同學都可以結成兩兩一組。
- 完美匹配要求二分圖兩部分的點數相等,因為若X中包括4個點,Y中包含5個點,則Y中必然會有一個點不會被匹配
- 完美匹配一定是最大匹配,最大匹配不一定是完美匹配
二分圖最大權完美匹配:假定有一個二分圖 G,每條邊有一個權值(可為負數),權值和最大的完美匹配是二分圖最大權完美匹配。
還有一些概念,二分圖最優匹配,二分圖最大權值匹配,二分圖最小權值匹配(將權值轉化為負數,即轉為最大權值匹配),都是指二分圖最大權完美匹配。
求解二分圖最大權完美匹配一般采用KM(Kuhn-Munkres)匹配算法
2. 匈牙利匹配算法
參考:https://zhuanlan.zhihu.com/p/105212518, https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session
2.1 匈牙利算法解析
匈牙利算法(Hungary Algorithm)是由Edmonds在1965年提出的,是求解二分圖最大匹配的經典算法,算法的核心就是根據一個初始匹配不停的找增廣路,直到沒有增廣路為止。幾個概念如下:
- 交替路:從任意一個未匹配點出發,依次經過未匹配邊-匹配邊-非匹配邊-匹配邊-未匹配邊……所得到的路徑被稱為交替路。(即未匹配邊和匹配邊交替出現)
- 增廣路:如果一條交替路的終點是一個未匹配點,那么這條路徑是增廣路,由於從未匹配點出發,又在未匹配點結束,未匹配邊比匹配邊多一條。
- 增廣路定理:如果可以找到一條增廣路,那么將匹配邊與未匹配邊互換,這個匹配就可以多一條邊,否則當前匹配就是最大匹配。即任意一個匹配是最大匹配的充分必要條件是不存在增廣路。
增廣路互換的實質可以這么考慮,如下圖:從未匹配點 A 出發,A 想與 B 匹配,於是通過未匹配邊找到 B,然而 B 已經是匹配點,於是只能經過匹配邊去問 C 能不能與別人匹配,C 經過未匹配邊找到 D,由於 D 是未匹配點,所以 C 成功與 D 匹配。CD 之間的邊變為匹配邊;BC 之間解除關系,變為未匹配邊;AB 之間建立關系,變為匹配邊。這便是增廣路互換的實質。
因此,總結下匈牙利算法的思想:就是不斷的尋找增廣路,如果找到,就互換匹配邊和非匹配邊,讓匹配邊增加一條,如果找不到匹配邊了,就表示已經是最大匹配了。
2.2 匈牙利算法代碼實現
python實現如下:
import math
import numpy as np
# 匈牙利匹配算法
class HungaryMatch(object):
def __init__(self, graph):
assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
assert graph.ndim == 2, print("二分圖的維度必須為2")
self.garph = graph
rows, cols = graph.shape
self.rows = rows
self.cols = cols
# self.vx = np.zeros(cols, dtype=np.int32) # visit flag, 橫向結點的訪問標志
# self.vy = np.zeros(rows, dtype=np.int32) # visit flag, 豎向結點的訪問標志
self.match_index = np.ones(cols, dtype=np.int32) * -1 # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
self.match_count = 0 # 總共有多少條匹配邊
def match(self):
for y in range(self.rows): # 從每一豎向結點開始,尋找增廣路
self.vx = np.zeros(self.cols, dtype=np.int32) # visit flag, 橫向結點的訪問標志置0
self.vy = np.zeros(self.rows, dtype=np.int32) # visit flag, 豎向結點的訪問標志置0
if self.dfs(y):
self.match_count += 1 # 采用dfs尋找增廣路,如果找到,匹配邊加1
return self.match_index, self.match_count
def dfs(self, y): # 遞歸版深度優先搜索
self.vy[y] = 1
for x in range(self.cols):
if self.vx[x] == 0 and self.garph[y][x] == 1: # 橫向結點x沒有訪問過,而且豎向結點y和橫向結點x有邊連接
self.vx[x] = 1
# 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
self.match_index[x] = y # 未匹配邊變成匹配邊
print(y, x, self.match_index)
return True
return False
if __name__ == '__main__':
graph = np.array([[0, 1, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 1, 0]])
hungary = HungaryMatch(graph)
index, count = hungary.match()
print(index) # [-1 1 2 0]:三組匹配邊(x, y): (1, 1), (2, 2), (3, 0)
print(count) # 3:共有三條匹配邊
cpp實現如下:
參考:https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session
bool dfs(int x){
for(int i=0; i<m; i++){
if (edge[x][i]==0 || vis[i]) continue;
vis[i] = true;
if (y_match[i]==-1 || dfs(y_match[i]))
return true;
}
return false;
}
int cnt = 0;
for (int i=0; i<n; i++){
memset(vis, false, sizeof(vis));
if (dfs(i))
cnt++;
}
3. KM算法(Kuhn-Munkres Algorithm)
參考:https://blog.sengxian.com/algorithms/km,https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/
3.1 可行頂標和相等子圖
二分圖最優匹配(最大權值匹配)的經典算法是由Kuhn和Munkres獨立提出的KM算法,值得一提的是最初的KM算法是在1955年和1957年提出的,因此當時的KM算法是以矩陣為基礎的,隨着匈牙利算法被Edmonds提出之后,現有的KM算法利用匈牙利樹可以得到更漂亮的實現。
KM算法是通過給每個頂點一個標號(叫做頂標,或者節點函數)來把求最大權完美匹配的問題轉化為求完美匹配的問題的。可以簡單理解為節點函數就是節點的一個值。幾個概念如下:
- 頂標(節點函數):指的是圖中的每個頂點,給它賦予一個值(就像邊的權重值),這個值也稱為節點函數值。
- 可行頂標:對於所有頂點的函數值\(l\),使得對於任意邊 \(e(x \rightarrow y)\),都滿足 \(l_{x} + l_{y} \ge W_{e}\),(其中,\(l_x\)為頂點x的頂標,\(l_y\)為頂點y的頂標,\(w_e\)為邊\(e(x \rightarrow y)\)的權值)
- 相等子圖:相等子圖包含原圖中所有的點,但只包含滿足 \(l_{x} + l_{y} = W_{e}\)的所有邊 \(e(x \rightarrow y)\)。根據定義,這些邊一定是當前權值最大的邊(不等式已經取到等號),那么如果相等子圖有完美匹配,那這個完美匹配一定是最大權值完美匹配。因為相等子圖的權值和為所有點的頂標之和,而隨便一個匹配中的邊因為受到 \(W_{e} \le l_{x} + l_{y}\)的限制,不可能比所有點的頂標之和大。
3.2 KM算法步驟解析
KM算法的主要目標就在於尋找可行頂標,使得相等子圖有完美匹配。可行頂標的修改過程中,每一步都運用了貪心的思想,這樣我們的最終結果一定是最優的。下面是算法的敘述:
步驟一:頂標初始化
因為有 \(l_{x} + l_{y} = W_{e}\)恆成立,我們設左側(Y集)的所有節點頂標為 0,那么所有 X集的點的頂標就必須為從它出發所有的邊的權值最大值。
步驟二:尋找完美匹配
尋找當前頂標條件下, 采用增廣路定理對每個點進行匹配(匈牙利算法),若最大匹配就是完美匹配,結束算法,否則必須修改頂標,使得有更多的邊能夠參與進來。
步驟三:修改頂標,加入更多可行頂標及對應邊
我們求當前相等子圖的完美匹配失敗,是因為對於某個未匹配頂點 u,我們找不到一條從它出發的增廣路,這時我們只能獲得一條交替路。我們把 X集中在交替路的點集叫做 S, X集中不在交替路的點集叫做 S',同理 Y集中在交替路的點集叫做 T, Y集中不在交替路的點集叫做 T'。如果我們把交替路中 X 集頂點的頂標(點集S中的點)全都減小某個值 d,Y集的頂標(點集T中的點)全都增加同一個值 d,那么我們會發現:
- 兩端都在交替路中的邊 \(e(i \rightarrow j)\),\(l_{i} + l_{j}\) 的值沒有變化。也就是說,它原來屬於相等子圖,現在仍屬於相等子圖。
- 兩端都不在交替路中的邊 \(e(i \rightarrow j)\),\(l_{i}, l_{j}\) 都沒有變化,\(l_{i} + l_{j}\) 的值沒有變化。也就是說,它原來屬於(或不屬於)相等子圖,現在仍屬於(或不屬於)相等子圖。
- X集一端在 S' 中, Y端在 T中的邊 \(e(i \rightarrow j)\),它的 \(l_{i}\)不變, \(l_{j}\)增加了d,\(l_{i} + l_{j}\)的值有所增大。它原來不屬於相等子圖,現在仍不可能屬於相等子圖。
- X集一端在 S中,Y 端在 T'中的邊\(e(i \rightarrow j)\),它的 \(l_{i}\)減小了d, \(l_{j}\)不變,\(l_{i} + l_{j}\)的值有所減小。也就說,它原來不屬於相等子圖,現在可能進入了相等子圖,因而使相等子圖得到了擴大。
也就是說,只有 X集一端在 S 中,Y端在 T'中的邊才有可能被選中。繼續貪心,我們只能讓滿足條件的邊權最大的邊被選中,即滿足\(l_{x} + l_{y} = W_{e}\),那么這個 d 值,就應該取 \(d = \min\{l_{x} + l_{y} - W_{e(x\rightarrow y)}\ \vert \ x \in S, y \in T'\}\)。
於是有新的邊加入相等子圖,我們可以愉快的繼續對於未匹配頂點 u尋找增廣路,這樣的修改最多進行n次,而一共有 n個點,所以除去修改頂標的時間,復雜度已經達到\(O(n^{2})\)。因此算法的復雜度主要取決於修改頂標的時間, 修改頂標主要兩個思路:
- 思路一:枚舉所有\(n^{2}\)條邊,看是否滿足條件,滿足條件就更新d值。最直觀清晰,然而總的復雜度飆升至\(O(n^{4})\)。
- 思路二:對於T'的每個點v,定義松弛變量\(slack(v) = \min\{l_{x}+l_{y} -W_{e(x\rightarrow y)}\ \vert\ x\in S\}\),這個松弛變量在匹配的過程中就可以更新,修改頂標的過程中\(d = \min\{slack(v)\ \vert\ v \in T'\}\)。總復雜度\(O(n^{3})\),但不是嚴格的(想一想為什么)?
3.3 KM算法步驟總結
KM算法僅僅只適用於找二分圖最佳完美匹配,如果無完美匹配,那么算法很可能陷入死循環(如果不存在的邊為 -INF 的話就不會,但正確性就無法保證了),對於這種情況要小心處理。
最后回顧一下總的流程,理一下思路:
- 初始化可行頂標。
- 用增廣路定理尋對每個點找匹配。
- 若點未找到匹配則修改可行頂標的值。
- 重復2、3步直到所有點均有匹配為止,即找到相等子圖的完美匹配為止
3.4 KM代碼實現
3.4.1 python實現
\(O(n^{4})\)版本:
# Kuhn-Munkres匹配算法, O(n^4)時間復雜度
class KMMatchOriginal(object):
def __init__(self, graph):
assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
assert graph.ndim == 2, print("二分圖的維度必須為2")
self.graph = graph
rows, cols = graph.shape
self.rows = rows
self.cols = cols
self.lx = np.zeros(self.cols, dtype=np.float32) # 橫向結點的頂標
self.ly = np.zeros(self.rows, dtype=np.float32) # 豎向結點的頂標
self.match_index = np.ones(cols, dtype=np.int32) * -1 # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
self.match_weight = 0 # 匹配邊的權值之和
def match(self):
# 初始化頂標, ly初始化為0,lx初始化為節點對應權值最大邊的權值
for y in range(self.rows):
self.ly[y] = max(self.graph[y, :])
for y in range(self.rows): # 從每一豎向結點開始,尋找增廣路
while True:
self.vx = np.zeros(self.cols, dtype=np.int32) # 橫向結點的匹配標志
self.vy = np.zeros(self.rows, dtype=np.int32) # 豎向結點的匹配標志
if self.dfs(y):
break
else:
self.update()
return self.match_index
# 更新頂標
def update(self):
d = np.inf
# 尋找y中已匹配,x中未匹配,對應需要減小的最小權值
for y in range(self.rows):
if self.vy[y]:
for x in range(self.cols):
if not self.vx[x]:
d = min(d, self.lx[x] + self.ly[y] - self.graph[y][x])
for x in range(self.cols): # x頂標初始化值為0,因此所有匹配點頂標+d
if self.vx[x]:
self.lx[x] += d
for y in range(self.rows): # y頂標初始化值為對應邊的最大權值,因此所有匹配點頂標-d
if self.vy[y]:
self.ly[y] -= d
def dfs(self, y): # 遞歸版深度優先搜索
self.vy[y] = 1
for x in range(self.cols):
if self.vx[x] == 0 and self.lx[x] + self.ly[y] == self.graph[y][x]:
self.vx[x] = 1
# 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
self.match_index[x] = y # 未匹配邊變成匹配邊
return True
return False
if __name__ == '__main__':
graph = np.array([[2,1,1],[3,2,1],[1,1,1]])
kmo = KMMatchOriginal(graph)
print(kmo.match())
\(O(n^{3})\)版本:
# Kuhn-Munkres匹配算法
class KMMatch(object):
def __init__(self, graph):
assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
assert graph.ndim == 2, print("二分圖的維度必須為2")
self.graph = graph
rows, cols = graph.shape
self.rows = rows
self.cols = cols
self.lx = np.zeros(self.cols, dtype=np.float32) # 橫向結點的頂標
self.ly = np.zeros(self.rows, dtype=np.float32) # 豎向結點的頂標
self.match_index = np.ones(cols, dtype=np.int32) * -1 # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
self.match_weight = 0 # 匹配邊的權值之和
self.inc = math.inf
def match(self):
# 初始化頂標, lx初始化為0,ly初始化為節點對應權值最大邊的權值
for y in range(self.rows):
self.ly[y] = max(self.graph[y, :])
for y in range(self.rows): # 從每一豎向結點開始,尋找增廣路
while True:
self.inc = np.inf
self.vx = np.zeros(self.cols, dtype=np.int32) # 橫向結點的匹配標志
self.vy = np.zeros(self.rows, dtype=np.int32) # 豎向結點的匹配標志
if self.dfs(y):
break
else:
self.update()
# print(y, self.lx, self.ly, self.vx, self.vy)
return self.match_index
# 更新頂標
def update(self):
for x in range(self.cols):
if self.vx[x]:
self.lx[x] += self.inc
for y in range(self.rows):
if self.vy[y]:
self.ly[y] -= self.inc
def dfs(self, y): # 遞歸版深度優先搜索
self.vy[y] = 1
for x in range(self.cols):
if self.vx[x] == 0:
t = self.lx[x] + self.ly[y] - self.graph[y][x]
if t == 0:
self.vx[x] = 1
# 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
self.match_index[x] = y # 未匹配邊變成匹配邊
# print(y, x, self.match_index)
return True
else:
if self.inc > t:
self.inc = t
return False
if __name__ == '__main__':
graph = np.array([[2, 1, 1], [3, 2, 1], [1, 1, 1]])
# # graph = np.array([[3,4,6,4,9],[6,4,5,3,8],[7,5,3,4,2],[6,3,2,2,5],[8,4,5,4,7]])
km = KMMatch(graph)
print(km.match())
在代碼撰寫過程中,踩了幾個坑,也發現了一些問題,總結如下:
- 在初始化頂標時,若行結點初始化為最大邊權值,列結點初始化為0,則必須從行結點出發,遍歷尋找滿足條件的增廣路,否則代碼會陷入死循環。(即從初始化為最大邊權值的結點開始遍歷)
- KM算法要求行結點和列結點個數相同,如果不相同時,保證行結點個數少,列結點個數多,然后通過padding來使行結點和列結點個數相同。
- KM算法求最大權值匹配,若要求最小權值匹配,可以對權值矩陣進行轉換,如采用一個很大值(如sys.maxint)減去權值矩陣
3.4.2 cpp代碼實現
\(O(n^{4})\)版本:
int Weight[maxm][maxn];
int Lx[maxm], Ly[maxn]; // 頂標
int match[maxn]; // 記錄匹配
bool S[maxm], T[maxn]; // 算法中的兩個集合S和T
// 步驟 1: 初始化可行頂標和初始化匹配
void Init()
{
// 將X集合的頂標設為最大邊權,Y集合的頂標設為0
for (int i = 1; i <= m; i++)
{
Lx[i] = 0;
for (int j = 1; j <= n; j++)
{
match[j] = 0; // match記錄的是Y集合里的點與誰匹配
Ly[j] = 0;
Lx[i] = max(Lx[i], Weight[i][j]);
}
}
}
//步驟2:增廣路定理尋找匹配點(匈牙利算法中的DFS)
bool findPath(int i)
{
S[i] = true;
for (int j = 1; j <= n; j++)
{
if (Lx[i] + Ly[j] == Weight[i][j] && !T[j]) // 找出在相等子圖里又還未被標記的邊
{
T[j] = true;
if (!match[j] || findPath(match[j])) // 未被匹配,或者已經匹配又找到增廣路
{
match[j] = i;
return true;
}
}
}
return false;
}
//步驟 3: 更新頂標
void update()
{
// 計算a
int a = 1 << 30;
for (int i = 1; i <= m; i++)
if (S[i])
for (int j = 1; j <= n; j++)
if (!T[j])
a = min(a, Lx[i] + Ly[j] - Weight[i][j]);
// 修改頂標
for (int i = 1; i <= m; i++)
if (S[i])
Lx[i] -= a;
for (int j = 1; j <= n; j++)
if (T[j])
Ly[j] += a;
}
// 整體的KM算法
void KM()
{
Init();
for (int i = 1; i <= m; i++)
{
while (true)
{
for (int i = 1; i <= m; i++)
S[i] = 0;
for (int j = 1; j <= n; j++)
T[j] = 0;
if (!findPath(i))
update();
else
break;
}
}
}
\(O(n^{3})\)版本:
const int maxn = 500 + 3, INF = 0x3f3f3f3f;
int n, W[maxn][maxn];
int mat[maxn];
int Lx[maxn], Ly[maxn], slack[maxn];
bool S[maxn], T[maxn];
inline void tension(int &a, const int b) {
if(b < a) a = b;
}
inline bool match(int u) {
S[u] = true;
for(int v = 0; v < n; ++v) {
if(T[v]) continue;
int t = Lx[u] + Ly[v] - W[u][v];
if(!t) {
T[v] = true;
if(mat[v] == -1 || match(mat[v])) {
mat[v] = u;
return true;
}
}else tension(slack[v], t);
}
return false;
}
inline void update() {
int d = INF;
for(int i = 0; i < n; ++i)
if(!T[i]) tension(d, slack[i]);
for(int i = 0; i < n; ++i) {
if(S[i]) Lx[i] -= d;
if(T[i]) Ly[i] += d;
}
}
inline void KM() {
for(int i = 0; i < n; ++i) {
Lx[i] = Ly[i] = 0; mat[i] = -1;
for(int j = 0; j < n; ++j) Lx[i] = max(Lx[i], W[i][j]);
}
for(int i = 0; i < n; ++i) {
fill(slack, slack + n, INF);
while(true) {
for(int j = 0; j < n; ++j) S[j] = T[j] = false;
if(match(i)) break;
else update();
}
}
}
參考:https://nymrli.top/2019/12/05/KM-Kuhn-Munkres-算法/
https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/
https://www.cnblogs.com/xingnie/p/10395788.html
4. Kuhn-Munkres算法開源包
在實際項目中涉及到最大權值匹配問題時,可以采用開源包中的Kuhn-Munkres算法,如下面兩個:
munkres
python有實現了munkres算法的安裝包,可以直接安裝:pip install munkres
官方使用文檔:https://software.clapper.org/munkres/
scipy
scipy模塊中scipy.optimize.linear_sum_assignment實現了KM匹配算法,可以直接調用。