匈牙利匹配和最大權值匹配算法


在使用多目標跟蹤算法時,接觸到了匈牙利匹配算法,一直沒時間好好總結下,現在來填坑。。

1. 基礎概念

1.1 二分圖

我們之前了解過圖(Graph)的概念,圖一般可以用G(V, E)來表示,V表示圖中的頂點,E表示圖中的邊。如下面,這個圖中有四個頂點,五條邊。

二分圖(Bipartite graph)是一類特殊的圖,它可以被划分為兩個部分,每個部分內的點互不相連,如下面是一個典型的二分圖,圖中的點可分為X,Y兩部分,X內部的點互補相連,Y內部的點也互不相連。我們也可以發現二分圖中一定不存在環。(二分圖又稱為二部圖,偶圖)

1.2 二分圖匹配

二分圖的匹配可以看成是二分圖的一個子圖,該子圖滿足以下條件:子圖中不存在有任意的兩條邊依附於同一個頂點

如下面左圖是一個二分圖,右圖就是它的一個匹配,右圖中每條邊都沒有公共端點,可以看出其是二分圖的一個子集。概念上有點繞,我們通俗點理解:有一個班級的學生要結成男女兩兩一組,但每個學生只想自己喜歡的異性結成一組,於是這就會有沖突,而匹配就是要找出這樣的男女組成,保證一個男生只和一個女生組合。

二分圖的匹配問題在有限資源分配時經常會用到,主要是為了保證某一個資源分且只分到某一個用戶的手中

1.3 二分圖最大匹配

二分圖最大匹配,就是在二分圖的所有匹配中,找出邊數最大的匹配。還是以上面的情景來理解:有一個班級的學生要結成男女兩兩一組,但每個學生只想自己喜歡的異性結成一組,匹配是保證一個男生只和一個女生組合,而最大匹配則是盡量保證沒有人落單,即二分圖最大匹配就是要給出一個最優方案,使得結成的組數最多

匈牙利算法就是尋找二分圖最大匹配方案的經典算法

1.4 二分圖最大權完美匹配

首先說二分圖完美匹配,如果一個二分圖的所有點都是匹配點(匹配邊中某一條邊的端點),則稱這個匹配是完美匹配。回到上面的情景,完美匹配就是可以得到一個方案,使得所有男女同學都可以結成兩兩一組。

  • 完美匹配要求二分圖兩部分的點數相等,因為若X中包括4個點,Y中包含5個點,則Y中必然會有一個點不會被匹配
  • 完美匹配一定是最大匹配,最大匹配不一定是完美匹配

二分圖最大權完美匹配:假定有一個二分圖 G,每條邊有一個權值(可為負數),權值和最大的完美匹配是二分圖最大權完美匹配。

還有一些概念,二分圖最優匹配,二分圖最大權值匹配,二分圖最小權值匹配(將權值轉化為負數,即轉為最大權值匹配),都是指二分圖最大權完美匹配。

求解二分圖最大權完美匹配一般采用KM(Kuhn-Munkres)匹配算法

2. 匈牙利匹配算法

參考:https://zhuanlan.zhihu.com/p/105212518, https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session

2.1 匈牙利算法解析

匈牙利算法(Hungary Algorithm)是由Edmonds在1965年提出的,是求解二分圖最大匹配的經典算法,算法的核心就是根據一個初始匹配不停的找增廣路,直到沒有增廣路為止。幾個概念如下:

  • 交替路:從任意一個未匹配點出發,依次經過未匹配邊-匹配邊-非匹配邊-匹配邊-未匹配邊……所得到的路徑被稱為交替路。(即未匹配邊和匹配邊交替出現)
  • 增廣路:如果一條交替路的終點是一個未匹配點,那么這條路徑是增廣路,由於從未匹配點出發,又在未匹配點結束,未匹配邊比匹配邊多一條。
  • 增廣路定理:如果可以找到一條增廣路,那么將匹配邊與未匹配邊互換,這個匹配就可以多一條邊,否則當前匹配就是最大匹配。即任意一個匹配是最大匹配的充分必要條件是不存在增廣路。

增廣路互換的實質可以這么考慮,如下圖:從未匹配點 A 出發,A 想與 B 匹配,於是通過未匹配邊找到 B,然而 B 已經是匹配點,於是只能經過匹配邊去問 C 能不能與別人匹配,C 經過未匹配邊找到 D,由於 D 是未匹配點,所以 C 成功與 D 匹配。CD 之間的邊變為匹配邊;BC 之間解除關系,變為未匹配邊;AB 之間建立關系,變為匹配邊。這便是增廣路互換的實質。

因此,總結下匈牙利算法的思想:就是不斷的尋找增廣路,如果找到,就互換匹配邊和非匹配邊,讓匹配邊增加一條,如果找不到匹配邊了,就表示已經是最大匹配了。

2.2 匈牙利算法代碼實現

python實現如下:

import math
import numpy as np

# 匈牙利匹配算法
class HungaryMatch(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
        assert graph.ndim == 2, print("二分圖的維度必須為2")
        self.garph = graph
        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        # self.vx = np.zeros(cols, dtype=np.int32)   # visit flag, 橫向結點的訪問標志
        # self.vy = np.zeros(rows, dtype=np.int32)  # visit flag, 豎向結點的訪問標志

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
        self.match_count = 0  # 總共有多少條匹配邊

    def match(self):
        for y in range(self.rows):  # 從每一豎向結點開始,尋找增廣路
            self.vx = np.zeros(self.cols, dtype=np.int32)  # visit flag, 橫向結點的訪問標志置0
            self.vy = np.zeros(self.rows, dtype=np.int32)  # visit flag, 豎向結點的訪問標志置0
            if self.dfs(y):
                self.match_count += 1  # 采用dfs尋找增廣路,如果找到,匹配邊加1
        return self.match_index, self.match_count

    def dfs(self, y):  # 遞歸版深度優先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0 and self.garph[y][x] == 1:  # 橫向結點x沒有訪問過,而且豎向結點y和橫向結點x有邊連接
                self.vx[x] = 1
                # 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
                if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                    self.match_index[x] = y  # 未匹配邊變成匹配邊
                    print(y, x, self.match_index)
                    return True
        return False
if __name__ == '__main__':
    graph = np.array([[0, 1, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 1, 0]])
    hungary = HungaryMatch(graph)
    index, count = hungary.match()
    print(index)  # [-1  1  2  0]:三組匹配邊(x, y): (1, 1), (2, 2), (3, 0)
    print(count)  # 3:共有三條匹配邊        

cpp實現如下:

參考:https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session

bool dfs(int x){
   for(int i=0; i<m; i++){
      if (edge[x][i]==0 || vis[i]) continue;
      vis[i] = true;
      if (y_match[i]==-1 || dfs(y_match[i]))
           return true;
   }
   return false;
}

int cnt = 0;
for (int i=0; i<n; i++){
    memset(vis, false, sizeof(vis));
    if (dfs(i))
         cnt++;
}

3. KM算法(Kuhn-Munkres Algorithm)

參考:https://blog.sengxian.com/algorithms/km,https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/

3.1 可行頂標和相等子圖

二分圖最優匹配(最大權值匹配)的經典算法是由Kuhn和Munkres獨立提出的KM算法,值得一提的是最初的KM算法是在1955年和1957年提出的,因此當時的KM算法是以矩陣為基礎的,隨着匈牙利算法被Edmonds提出之后,現有的KM算法利用匈牙利樹可以得到更漂亮的實現。

KM算法是通過給每個頂點一個標號(叫做頂標,或者節點函數)來把求最大權完美匹配的問題轉化為求完美匹配的問題的。可以簡單理解為節點函數就是節點的一個值。幾個概念如下:

  • 頂標(節點函數):指的是圖中的每個頂點,給它賦予一個值(就像邊的權重值),這個值也稱為節點函數值。
  • 可行頂標:對於所有頂點的函數值\(l\),使得對於任意邊 \(e(x \rightarrow y)\),都滿足 \(l_{x} + l_{y} \ge W_{e}\),(其中,\(l_x\)為頂點x的頂標,\(l_y\)為頂點y的頂標,\(w_e\)為邊\(e(x \rightarrow y)\)的權值)
  • 相等子圖:相等子圖包含原圖中所有的點,但只包含滿足 \(l_{x} + l_{y} = W_{e}\)的所有邊 \(e(x \rightarrow y)\)。根據定義,這些邊一定是當前權值最大的邊(不等式已經取到等號),那么如果相等子圖有完美匹配,那這個完美匹配一定是最大權值完美匹配。因為相等子圖的權值和為所有點的頂標之和,而隨便一個匹配中的邊因為受到 \(W_{e} \le l_{x} + l_{y}\)的限制,不可能比所有點的頂標之和大。

3.2 KM算法步驟解析

KM算法的主要目標就在於尋找可行頂標,使得相等子圖有完美匹配。可行頂標的修改過程中,每一步都運用了貪心的思想,這樣我們的最終結果一定是最優的。下面是算法的敘述:

步驟一:頂標初始化

因為有 \(l_{x} + l_{y} = W_{e}\)恆成立,我們設左側(Y集)的所有節點頂標為 0,那么所有 X集的點的頂標就必須為從它出發所有的邊的權值最大值。

步驟二:尋找完美匹配

尋找當前頂標條件下, 采用增廣路定理對每個點進行匹配(匈牙利算法),若最大匹配就是完美匹配,結束算法,否則必須修改頂標,使得有更多的邊能夠參與進來。

步驟三:修改頂標,加入更多可行頂標及對應邊

我們求當前相等子圖的完美匹配失敗,是因為對於某個未匹配頂點 u,我們找不到一條從它出發的增廣路,這時我們只能獲得一條交替路。我們把 X集中在交替路的點集叫做 S, X集中不在交替路的點集叫做 S',同理 Y集中在交替路的點集叫做 T, Y集中不在交替路的點集叫做 T'。如果我們把交替路中 X 集頂點的頂標(點集S中的點)全都減小某個值 d,Y集的頂標(點集T中的點)全都增加同一個值 d,那么我們會發現:

  • 兩端都在交替路中的邊 \(e(i \rightarrow j)\)\(l_{i} + l_{j}\) 的值沒有變化。也就是說,它原來屬於相等子圖,現在仍屬於相等子圖。
  • 兩端都不在交替路中的邊 \(e(i \rightarrow j)\)\(l_{i}, l_{j}\) 都沒有變化,\(l_{i} + l_{j}\) 的值沒有變化。也就是說,它原來屬於(或不屬於)相等子圖,現在仍屬於(或不屬於)相等子圖。
  • X集一端在 S' 中, Y端在 T中的邊 \(e(i \rightarrow j)\),它的 \(l_{i}\)不變, \(l_{j}\)增加了d,\(l_{i} + l_{j}\)的值有所增大。它原來不屬於相等子圖,現在仍不可能屬於相等子圖。
  • X集一端在 S中,Y 端在 T'中的邊\(e(i \rightarrow j)\),它的 \(l_{i}\)減小了d, \(l_{j}\)不變,\(l_{i} + l_{j}\)的值有所減小。也就說,它原來不屬於相等子圖,現在可能進入了相等子圖,因而使相等子圖得到了擴大。

也就是說,只有 X集一端在 S 中,Y端在 T'中的邊才有可能被選中。繼續貪心,我們只能讓滿足條件的邊權最大的邊被選中,即滿足\(l_{x} + l_{y} = W_{e}\),那么這個 d 值,就應該取 \(d = \min\{l_{x} + l_{y} - W_{e(x\rightarrow y)}\ \vert \ x \in S, y \in T'\}\)

於是有新的邊加入相等子圖,我們可以愉快的繼續對於未匹配頂點 u尋找增廣路,這樣的修改最多進行n次,而一共有 n個點,所以除去修改頂標的時間,復雜度已經達到\(O(n^{2})\)。因此算法的復雜度主要取決於修改頂標的時間, 修改頂標主要兩個思路:

  • 思路一:枚舉所有\(n^{2}\)條邊,看是否滿足條件,滿足條件就更新d值。最直觀清晰,然而總的復雜度飆升至\(O(n^{4})\)
  • 思路二:對於T'​的每個點v,定義松弛變量\(slack(v) = \min\{l_{x}+l_{y} -W_{e(x\rightarrow y)}\ \vert\ x\in S\}\),這個松弛變量在匹配的過程中就可以更新,修改頂標的過程中\(d = \min\{slack(v)\ \vert\ v \in T'\}\)。總復雜度\(O(n^{3})\),但不是嚴格的(想一想為什么)?

3.3 KM算法步驟總結

KM算法僅僅只適用於找二分圖最佳完美匹配,如果無完美匹配,那么算法很可能陷入死循環(如果不存在的邊為 -INF 的話就不會,但正確性就無法保證了),對於這種情況要小心處理。
最后回顧一下總的流程,理一下思路:

  1. 初始化可行頂標。
  2. 用增廣路定理尋對每個點找匹配。
  3. 若點未找到匹配則修改可行頂標的值。
  4. 重復2、3步直到所有點均有匹配為止,即找到相等子圖的完美匹配為止

3.4 KM代碼實現

3.4.1 python實現

\(O(n^{4})\)版本:

# Kuhn-Munkres匹配算法, O(n^4)時間復雜度
class KMMatchOriginal(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
        assert graph.ndim == 2, print("二分圖的維度必須為2")
        self.graph = graph

        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        self.lx = np.zeros(self.cols, dtype=np.float32)  # 橫向結點的頂標
        self.ly = np.zeros(self.rows, dtype=np.float32)  # 豎向結點的頂標

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
        self.match_weight = 0  # 匹配邊的權值之和

    def match(self):
        # 初始化頂標, ly初始化為0,lx初始化為節點對應權值最大邊的權值
        for y in range(self.rows):
            self.ly[y] = max(self.graph[y, :])

        for y in range(self.rows):  # 從每一豎向結點開始,尋找增廣路
            while True:
                self.vx = np.zeros(self.cols, dtype=np.int32)  # 橫向結點的匹配標志
                self.vy = np.zeros(self.rows, dtype=np.int32)  # 豎向結點的匹配標志
                if self.dfs(y):
                    break
                else:
                    self.update()
        return self.match_index

    # 更新頂標
    def update(self):
        d = np.inf
        # 尋找y中已匹配,x中未匹配,對應需要減小的最小權值
        for y in range(self.rows):
            if self.vy[y]:
                for x in range(self.cols):
                    if not self.vx[x]:
                        d = min(d, self.lx[x] + self.ly[y] - self.graph[y][x])

        for x in range(self.cols):  # x頂標初始化值為0,因此所有匹配點頂標+d
            if self.vx[x]:
                self.lx[x] += d
        for y in range(self.rows):  # y頂標初始化值為對應邊的最大權值,因此所有匹配點頂標-d
            if self.vy[y]:
                self.ly[y] -= d

    def dfs(self, y):  # 遞歸版深度優先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0 and self.lx[x] + self.ly[y] == self.graph[y][x]:
                self.vx[x] = 1
                # 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
                if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                    self.match_index[x] = y  # 未匹配邊變成匹配邊
                    return True
        return False
if __name__ == '__main__':
    graph = np.array([[2,1,1],[3,2,1],[1,1,1]])
    kmo = KMMatchOriginal(graph)
    print(kmo.match())

\(O(n^{3})\)版本:

# Kuhn-Munkres匹配算法
class KMMatch(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分圖的必須采用numpy array 格式")
        assert graph.ndim == 2, print("二分圖的維度必須為2")
        self.graph = graph

        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        self.lx = np.zeros(self.cols, dtype=np.float32)  # 橫向結點的頂標
        self.ly = np.zeros(self.rows, dtype=np.float32)  # 豎向結點的頂標

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 橫向結點匹配的豎向結點的index (默認-1,表示未匹配任何豎向結點)
        self.match_weight = 0  # 匹配邊的權值之和

        self.inc = math.inf

    def match(self):
        # 初始化頂標, lx初始化為0,ly初始化為節點對應權值最大邊的權值
        for y in range(self.rows):
            self.ly[y] = max(self.graph[y, :])

        for y in range(self.rows):  # 從每一豎向結點開始,尋找增廣路
            while True:
                self.inc = np.inf
                self.vx = np.zeros(self.cols, dtype=np.int32)  # 橫向結點的匹配標志
                self.vy = np.zeros(self.rows, dtype=np.int32)  # 豎向結點的匹配標志
                if self.dfs(y):
                    break
                else:
                    self.update()
                # print(y, self.lx, self.ly, self.vx, self.vy)
        return self.match_index

    # 更新頂標
    def update(self):
        for x in range(self.cols):
            if self.vx[x]:
                self.lx[x] += self.inc
        for y in range(self.rows):
            if self.vy[y]:
                self.ly[y] -= self.inc

    def dfs(self, y):  # 遞歸版深度優先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0:
                t = self.lx[x] + self.ly[y] - self.graph[y][x]
                if t == 0:
                    self.vx[x] = 1
                    # 兩種情況:一是結點x沒有匹配,那么找到一條增廣路;二是X結點已經匹配,采用DFS,沿着X繼續往下走,最后若以未匹配點結束,則也是一條增廣路
                    if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                        self.match_index[x] = y  # 未匹配邊變成匹配邊
                        # print(y, x, self.match_index)
                        return True
                else:
                    if self.inc > t:
                        self.inc = t
        return False
if __name__ == '__main__':
    graph = np.array([[2, 1, 1], [3, 2, 1], [1, 1, 1]])
    # # graph = np.array([[3,4,6,4,9],[6,4,5,3,8],[7,5,3,4,2],[6,3,2,2,5],[8,4,5,4,7]])
    km = KMMatch(graph)
    print(km.match())

在代碼撰寫過程中,踩了幾個坑,也發現了一些問題,總結如下:

  • 在初始化頂標時,若行結點初始化為最大邊權值,列結點初始化為0,則必須從行結點出發,遍歷尋找滿足條件的增廣路,否則代碼會陷入死循環。(即從初始化為最大邊權值的結點開始遍歷
  • KM算法要求行結點和列結點個數相同,如果不相同時,保證行結點個數少,列結點個數多,然后通過padding來使行結點和列結點個數相同
  • KM算法求最大權值匹配,若要求最小權值匹配,可以對權值矩陣進行轉換,如采用一個很大值(如sys.maxint)減去權值矩陣
3.4.2 cpp代碼實現

\(O(n^{4})\)版本:

int Weight[maxm][maxn];
int Lx[maxm], Ly[maxn]; // 頂標
int match[maxn];    // 記錄匹配
bool S[maxm], T[maxn];  // 算法中的兩個集合S和T

// 步驟 1: 初始化可行頂標和初始化匹配
void Init()
{
    // 將X集合的頂標設為最大邊權,Y集合的頂標設為0
    for (int i = 1; i <= m; i++)
    {
        Lx[i] = 0;
        for (int j = 1; j <= n; j++)
        {
            match[j] = 0;   // match記錄的是Y集合里的點與誰匹配
            Ly[j] = 0;
            Lx[i] = max(Lx[i], Weight[i][j]);
        }
    }
}
//步驟2:增廣路定理尋找匹配點(匈牙利算法中的DFS)
bool findPath(int i)
{
    S[i] = true;
    for (int j = 1; j <= n; j++)
    {
        if (Lx[i] + Ly[j] == Weight[i][j] && !T[j]) // 找出在相等子圖里又還未被標記的邊
        {
            T[j] = true;
            if (!match[j] || findPath(match[j])) // 未被匹配,或者已經匹配又找到增廣路
            {
                match[j] = i;
                return true;
            }
        }
    }
    return false;
}

//步驟 3: 更新頂標
void update() 
{
    // 計算a
    int a = 1 << 30;
    for (int i = 1; i <= m; i++)
        if (S[i])
            for (int j = 1; j <= n; j++)
                if (!T[j])
                    a = min(a, Lx[i] + Ly[j] - Weight[i][j]);

    // 修改頂標
    for (int i = 1; i <= m; i++)
        if (S[i])
            Lx[i] -= a;
    for (int j = 1; j <= n; j++)
        if (T[j]) 
            Ly[j] += a;
}
// 整體的KM算法
void KM()
{
    Init();

    for (int i = 1; i <= m; i++)
    {
        while (true)
        {
            for (int i = 1; i <= m; i++)
                S[i] = 0;
            for (int j = 1; j <= n; j++)
                T[j] = 0;
            if (!findPath(i))
                update();
            else
                break;
        }
    }
}

\(O(n^{3})\)版本:

const int maxn = 500 + 3, INF = 0x3f3f3f3f;
int n, W[maxn][maxn];
int mat[maxn];
int Lx[maxn], Ly[maxn], slack[maxn];
bool S[maxn], T[maxn];

inline void tension(int &a, const int b) {
    if(b < a) a = b;
}

inline bool match(int u) {
    S[u] = true;
    for(int v = 0; v < n; ++v) {
        if(T[v]) continue;
        int t = Lx[u] + Ly[v] - W[u][v];
        if(!t) {
            T[v] = true;
            if(mat[v] == -1 || match(mat[v])) {
                mat[v] = u;
                return true;
            }
        }else tension(slack[v], t);
    }
    return false;
}

inline void update() {
    int d = INF;
    for(int i = 0; i < n; ++i)
        if(!T[i]) tension(d, slack[i]);
    for(int i = 0; i < n; ++i) {
        if(S[i]) Lx[i] -= d;
        if(T[i]) Ly[i] += d;
    }
}

inline void KM() {
    for(int i = 0; i < n; ++i) {
        Lx[i] = Ly[i] = 0; mat[i] = -1;
        for(int j = 0; j < n; ++j) Lx[i] = max(Lx[i], W[i][j]);
    }
    for(int i = 0; i < n; ++i) {
        fill(slack, slack + n, INF);
        while(true) {
            for(int j = 0; j < n; ++j) S[j] = T[j] = false;
            if(match(i)) break;
            else update();
        }
    }
}

參考:https://nymrli.top/2019/12/05/KM-Kuhn-Munkres-算法/

https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/

https://www.cnblogs.com/xingnie/p/10395788.html

4. Kuhn-Munkres算法開源包

在實際項目中涉及到最大權值匹配問題時,可以采用開源包中的Kuhn-Munkres算法,如下面兩個:

munkres

python有實現了munkres算法的安裝包,可以直接安裝:pip install munkres

官方使用文檔:https://software.clapper.org/munkres/

scipy

scipy模塊中scipy.optimize.linear_sum_assignment實現了KM匹配算法,可以直接調用。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM