[算法筆記] 回溯法總結


本文復習一下回溯法,包括遞歸型和非遞歸型,通過下面 2 個例子來解析回溯法:

  • 全排列問題
  • n 皇后問題
  • 三着色問題

回溯法

在許多遞歸問題當中,我們采取的方法都是窮盡所有的可能,從而找出合法的解。但是在某些情況下,當遞歸到某一層的時候,根據設置的判斷條件,可以 judge 此解是不合法的。在這種情況下,我們就沒必要再進行深層次的遞歸,從而可以提高算法效率。這一類算法我們稱為“回溯法”,設置的判斷條件稱為“剪枝函數”。

回溯法的遞歸形式:

Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)

back-track-rec(int now)
{
    for x0 in X
    {
        T[now] = x0
        if (T[0...now] is valid)  //如果有效則進行,否則嘗試下一個x0
        {
            if (now == n)  //是完整解
            {
                print(T[1...now]);
                return;
            }
            else if (now < n)  //是部分解
            {
                back-track-rec(now+1);
            }
        }
    }
}

在可計算理論中,有這么一個結論:

所有的遞歸函數都能轉換為迭代,但迭代不一定能轉換為遞歸。

我們知道,C語言當中,函數調用是通過棧來實現的。遞歸實質是不斷進行函數調用,直至參數達到遞歸的邊界。所以,理論上,只要允許使用棧,那么回溯法就可以通過迭代實現。

回溯法的非遞歸形式:

Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)

back-track-itor()
{
    int top = 0;
    while (top >= 0)
    {
        while T[top] 沒有取盡 X 中的元素
        {
            T[top] = next(X)
            if (check(top) is valid)
            {
                if (top < N)    //部分解
                    print();
                else
                    top++;
            }
        }
        reset(T[top])
        top--
    }
}

使用一句話來描述回溯法的思想:對於 T[i], 嘗試每一個 x0, 只要 T[i]=x0 有效,則對 T[i+1] 進行嘗試,否則回退到 T[i-1] .

全排列問題

給出一個 N ,輸出 N 的全排列。

首先,根據回溯法的遞歸形式的模板,可以寫出下面的代碼:

void backTrackRec2(int now)
{
    for (int i = 1; i <= N; i++)
    {
        a[now] = i;
        if (check(now))
        {
            if (now == N - 1)
            {
                print(N);
                return;
            }
            else
            {
                backTrackRec2(now + 1);
            }
        }
    }
}

而關鍵就是如何實現 check 函數去檢查是否當前填入的 i 是否有效,全排列的 check 函數很簡單:只需要 a[0...now-1] 都與 a[now] 不相等。

bool check(int now)
{
    for (int i = 0; i < now; i++)
    {
        if (a[i] == a[now])
            return false;
    }
    return true;
}

現在分析一下算法復雜度,對於每一個排列,需要對 a[0,...,(N-1)] 都執行一次 check,那么求解一個序列的復雜度為:

0 + 1 + 2 + ... + (n-1) = n(n-1)/2

現在思考如何把 check 的方法簡化:開一個長度為 N+1bool 數組 table[] ,如果數字 k 已經被使用了,那么置 table[k] = true 。復雜度為 O(1)

void backTrackRec1(int a[], int N, int now)
{
    if (now == N)
    {
        print(N);
        return;
    }
    for (int x = 1; x <= N; x++)
    {
        if (table[x] == false)
        {
            a[now] = x, table[x] = true;
            backTrackRec1(a, N, now + 1);
            table[x] = false;
        }
    }
}

最后給出非遞歸形式的解法,a[] 相當於一個棧,k 是棧頂指針,k++ 表示進棧, k-- 表示出棧(也是回溯的過程)。

void backTrackItor()
{
    int k = 0;
    while (k >= 0)
    {
        while (a[k] < N)
        {
            a[k]++;
            if (check(k))
            {
                if (k == N - 1)
                {
                    print(N);
                    break;
                }
                else
                {
                    k++;
                }
            }
        }
        a[k] = 0;
        k--;
    }
}

n 皇后問題

使用數組 pos[N] 來表示皇后的位置,pos[i] = j 表示第 i 個皇后在位置 (i,j)

首先來看遞歸形式的解法:

void backTrackRec(int now)
{
    if (now == N)
    {
        print();
        return;
    }
    for (int x = 0; x < N; x++)
    {
        pos[now] = x;
        if (check(now))
        {
            backTrackRec(now + 1);
        }
    }
}

我們使用 pos 數組來記錄位置,已經能保證每個皇后在不同的行上。因此,在 check 函數當中,需要檢查新添的皇后是否有同列或者在對角線上(兩點斜率為 1 )的情況。

bool check(int index)
{
    for (int i = 0; i < index; i++)
    {
        if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
            return false;
    }
    return true;
}

再來看非遞歸的解法:

void backTrackItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (pos[top] < N)
        {
            pos[top]++;
            if (check(top))
            {
                if (top == N-1)
                {
                    print();
                    break;
                }
                else
                {
                    top++;
                }
                
            }
        }
        pos[top--] = 0;
    }
}

本質上 n 皇后問題還是在做全排列的枚舉,但是因為 check 函數的不同,實際上空間復雜度要小一些。例如當出現:「1 2」 這種情況,就會被剪枝函數 check 裁去,不再進行深一層的搜索。

三着色問題

三着色問題是指:給出一個無向圖 G=(V,E), 使用三種不同的顏色為 G 中的每一個頂點着色,使得沒有 2 個相鄰的點具有相同的顏色。

首先,我們使用如下的數據結構:

map<int, vector<int>> graph;  //圖的鄰接鏈表表示
int v, e;  //點數,邊數
int table[VMAX]; //table[i]=0/1/2, 表示點 i 塗上顏色 R/G/B

很自然的想法,我們會窮舉每一個顏色序列,找出合法的解,假設有 3 個頂點,那么自然會這樣嘗試:

0 0 0
0 0 1
0 0 2
...

但是,這樣的窮舉並不是想要的結果,因為嘗試的過程中沒有加入 “沒有 2 個相鄰的點具有相同的顏色” 這樣的判斷。

還是直接套回溯法的模板:


void colorRec(int now)
{
    for (int i = 0; i < NCOLOR; i++)
    {
        table[now] = i;
        if (check(now))
        {
            if (now == v - 1) //完整解
            {
                print(v);
                countRec++;
                //不應有 return;
            }
            else
            {
                colorRec(now + 1);
            }
        }
    }
    table[now] = -1;
}
void colorItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (table[top] < (NCOLOR - 1))
        {
            table[top]++;
            if (check(top))
            {
                if (top == v - 1)
                {
                    print(v);
                    countItor++;
                    // 不應有 break;
                }
                else
                {
                    top++;
                }
            }
        }
        table[top--] = -1;
    }
}

注意上面兩處的「不應有」,這是與全排列和 n 皇后有所區別的地方。為什么呢?

假設現有 4 個頂點:

A-----B
|     
C-----D

一個合法的着色序列為:

0 1 2 0

如果對應的地方有 break 或者 return,那么上述序列就會回溯到「0 1 2」這個序列,但是實際上,在上面序列的基礎上繼續搜索,可以找到:

0 1 2 1

這也是一個合法的着色序列,如果加入 breakreturn ,這種情況就被忽略了。

附錄

3着色代碼

#include <cstring>
#include <iostream>
#include <map>
#include <vector>
#define NCOLOR 3
#define VMAX 100
#define EMAX 200
using namespace std;
map<int, vector<int>> graph;
int v, e;
int table[VMAX]; //table[i]=R/G/B, 表示點 i 塗上顏色 R/G/B
int countRec = 0, countItor = 0;
bool check(int now)
{
    for (int x : graph[now])
    {
        if (table[x] != -1 && table[x] == table[now])
            return false;
    }
    return true;
}
void print(int len)
{
    cout << "Point: ";
    for (int i = 0; i < len; i++)
    {
        cout << i << ' ';
    }
    cout << endl;
    cout << "Color: ";
    for (int i = 0; i < len; i++)
    {
        cout << table[i] << ' ';
    }
    cout << "\n"
         << endl;
}
void colorRec(int now)
{
    for (int i = 0; i < NCOLOR; i++)
    {
        table[now] = i;
        if (check(now))
        {
            if (now == v - 1) //完整解
            {
                print(v);
                countRec++;
                //不應有 return;
            }
            else
            {
                colorRec(now + 1);
            }
        }
    }
    table[now] = -1;
}
void colorItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (table[top] < (NCOLOR - 1))
        {
            table[top]++;
            if (check(top))
            {
                if (top == v - 1)
                {
                    print(v);
                    countItor++;
                    // 不應有 break;
                }
                else
                {
                    top++;
                }
            }
        }
        table[top--] = -1;
    }
}
int main()
{
    memset(table, -1, sizeof(table));
    cin >> v >> e;
    int a, b;
    for (int i = 0; i < e; i++)
    {
        cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    // colorRec(0);
    memset(table, -1, sizeof(table));
    colorItor();
    cout << countRec << " " << countItor << endl;
}

/*
Sample1:
5 7
0 1
0 2
1 3
1 4
2 3
2 4
3 4
 */

全排列代碼

#include <iostream>
#include <cstring>
#define MAXN 20
using namespace std;
int a[MAXN] = {0};
bool table[MAXN] = {0};
int N = 0;
void print(int n)
{
    for (int i = 0; i < n; i++)
    {
        cout << a[i] << ' ';
    }
    cout << endl;
}
bool check(int now)
{
    for (int i = 0; i < now; i++)
    {
        if (a[i] == a[now])
            return false;
    }
    return true;
}
void backTrackRec1(int a[], int N, int now)
{
    if (now == N)
    {
        print(N);
        return;
    }
    for (int x = 1; x <= N; x++)
    {
        if (table[x] == false)
        {
            a[now] = x, table[x] = true;
            backTrackRec1(a, N, now + 1);
            table[x] = false;
        }
    }
}
void backTrackRec2(int now)
{
    for (int i = 1; i <= N; i++)
    {
        a[now] = i;
        if (check(now))
        {
            if (now == N - 1)
            {
                print(N);
                return;
            }
            else
            {
                backTrackRec2(now + 1);
            }
        }
    }
}
void backTrackItor()
{
    int k = 0;
    while (k >= 0)
    {
        while (a[k] < N)
        {
            a[k]++;
            if (check(k))
            {
                if (k == N - 1)
                {
                    print(N);
                    break;
                }
                else
                {
                    k++;
                }
            }
        }
        a[k] = 0;
        k--;
    }
}

int main()
{
    N = 3;
    for (int i = 1; i <= N; i++)
    {
        a[i - 1] = 0;
    }
    // backTrackRec1(a, N, 0);
    // backTrackRec2(0);
    backTrackItor();
}

n皇后代碼

#include <iostream>
using namespace std;
#define N 8
int count = 0;
int pos[N] = {0};
void print()
{
    count++;
    for (int i = 0; i < N; i++)
    {
        int r = i;
        int c = pos[i];
        for (int i = 0; i < c; i++)
            cout << "* ";
        cout << "Q ";
        for (int i = c + 1; i < N; i++)
            cout << "* ";
        cout << endl;
    }
    cout << endl;
}

bool check(int index)
{
    for (int i = 0; i < index; i++)
    {
        if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
            return false;
    }
    return true;
}

void backTrackRec(int now)
{
    if (now == N)
    {
        print();
        return;
    }
    for (int x = 0; x < N; x++)
    {
        pos[now] = x;
        if (check(now))
        {
            backTrackRec(now + 1);
        }
    }
}
void backTrackItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (pos[top] < N)
        {
            pos[top]++;
            if (check(top))
            {
                if (top == N-1)
                {
                    print();
                    break;
                }
                else
                {
                    top++;
                }
                
            }
        }
        pos[top--] = 0;
    }
}
int main()
{
    // backTrackRec(0);
    backTrackItor();
    cout << count << endl;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM