【學習記錄】二分查找的C++實現,代碼逐步優化


二分查找的思想很簡單,它是針對於有序數組的,相當於數組(設為int a[N])排成一顆二叉平衡樹(左子節點<=父節點<=右子節點),然后從根節點(對應數組下標a[N/2])開始判斷,若值<=當前節點則到左子樹,否則到右子樹。查找時間復雜度是O(logN),因為樹的高度是logN。

二分查找分兩種:一種是精確查找某個元素x,另一種則是根據比較關系式查找,比如返回i使得對任意j<i均有a[j]<a[i],這用在折半插入排序中。

剛好前天無聊不借助STL手寫折半插入排序時,發現自己基本功不扎實,對比了下STL自己實現的二分查找,發現自己還是太嫩了。

函數簽名采用C風格的template <typename> T* f(T* a, size_t n, const T& x); 即查找區間為數組T a[n]整個區間,返回找到元素的地址。

下面給出對第一種查找的我第一感覺的實現方法

template <typename T>
T* binary_search(T* a, size_t n, const T& x)
{
    size_t low = 0;
    size_t high = n - 1;
    size_t mid = (low + high) / 2;
    while (low < mid)
    {
        if (x == a[mid])
            return a + mid;
        else if (x < a[mid])  // x位於[low, mid)區間
            high = mid - 1;   // 縮小查找范圍到[low, mid-1]
        else // x > a[mid]
            low = mid + 1;    // 縮小查找范圍到[mid+1, high]
        mid = (low + high) / 2;
    }
    return nullptr;  // 查找失敗
}

測試代碼如下

#include "binary_search.h"
#include <cstdio>

int main()
{
    const size_t N = 8;
    int a[N] = {1,2,3,4,5,6,7,8};
    auto p = binary_search(a, N, 3);
    if (p != nullptr)
        printf("a[%d] = %d\n", p - a, *p);
    return 0;
}

結果什么也沒輸出。調試發現在mid=2時由於low=2此時while循環退出,而此時本應該比較一下x和a[mid]的。

肉眼調試,初始:low=0, high=7, mid=(0+7)/2=3,進入while循環。

low=0<3=mid,a[3]=4>3,設置high=3,計算mid=(0+3)/2=1,進入下次循環

low=0<1=mid,a[1]=2<3,設置low=2,計算mid=(2+3)/2=2,進入下次循環

low=2=mid,跳出循環。返回默認返回值nullptr

可能這里會想說,那把while的條件改成<=不就行了?這樣會造成死循環,因為(a+b)/2>=a恆成立。

那么讓相等的時候比較一次就退出呢?

也不行,假如low=0,high=1,mid永遠等於0,最可怕的是,這時候不會跟a[1]進行比較。

正是第一感覺考慮到這個,所以我沒有用while (low < high),因為此時也會死循環。

——根本問題出現了,二分法每次都會對半切分區間,但是有時候區間大小(減1)為奇數,那么兩個子區間大小肯定不一樣。

對兩個數組下標low和high(low<=high),按照(low+high)/2得到的mid把數組划分成哪兩部分呢?

對於所有對半縮小區間問題,最后都會變成2種情況:1、low=mid=high;2、low=mid=high-1。

設k為整數:

假如low+high=2k,那么mid=k,左區間[low, mid-1]大小為mid-low=k-low,右區間[mid+1,high]大小為high-k=2k-low-k=k-low,左右區間相等;

假如low+high=2k+1,那么mid=k,左區間大小仍為k-low,右區間大小為2k+1-low-k=k-low+1,比左區間多1個。

也就是如果仍然以while (low<high)來判斷,那么在結束循環后,在return nullptr;之前要加下面幾行代碼判斷。

    if (a[mid] == x)  // low == mid
        return a + mid;
    if (low < high && a[high] == x)  // low == mid == high-1
        return a + high;

稍微顯得不太美觀了,能不能一個while循環就搞定還不用做額外判斷的呢?

考慮下STL采用的的左閉右開區間[first, last),對數組a[N]而言first=0,last=N,不用特地寫N-1。

假如first+last=2k,那么mid=k,左區間[first, mid)大小為mid-first=k-first,右區間[mid+1, last)大小為(2k-first)-(mid+1)=k-first-1,比左區間少1個。

假如first+last=2k+1,兩個區間一樣大(這里不給出計算流程了)。 給出這種情況下的代碼

T* binary_search(T* a, size_t n, const T& x)
{
    size_t first = 0;
    size_t last = n;
    size_t mid = (first + last) / 2;
    while (first < mid)
    {
        if (x == a[mid])
            return a + mid;
        else if (x < a[mid])
            last = mid;  // [first, mid)
        else // x > a[mid], [mid+1, last)
            first = mid + 1;
        mid = (first + last) / 2;
    }
    if (a[mid] == x)
        return a + mid;
    return nullptr;
}

測試代碼

    for (int x = 1; x <= 8; x++)
    {
        auto p = binary_search(a, N, x);
        if (p != nullptr)
            printf("a[%d] = %d\n", p - a, *p);
    }

結果無誤,突然覺得STL采用左閉右開區間有其道理了。雖然感覺代碼還是有些丑陋,再刪一行的話充其量就是寫成while ((mid = (first + last) / 2) > first)的形式,減少了while循環體內的一行重復代碼mid = (first + last) / 2;

這樣就完了?NO!這種實現健壯性不夠,因為會出現溢出!

從最開始我就錯了!假設,這里僅僅是假設,size_t的上限為3,即2 bits 無符號整數。當first為1(01),last為3(11)時,兩者相加會溢出(超出size_t能表示的范圍[0, 4))

理想的結果本來是2,實際結果(假如運算規則是超出最高位的進位直接省略掉)

二進制運算01 + 11 = 100,省略最高位1,變成了00,然后00除以2,結果是00,即0。也就是(1+3)/2的結果不是2,而是0!

當然,像我這種比較隨意的簡單測試不會出現溢出情況,但是溢出的風險必須考慮!(想起當時刷了道二分法的leetcode題應該就是掛在這里了)

因此:mid = (first + last) / 2要改成mid = first + (last - first) / 2

 

本着學習的態度,去看了官網上的binary_search的實現http://www.cplusplus.com/reference/algorithm/binary_search/

template <class ForwardIterator, class T>
  bool binary_search (ForwardIterator first, ForwardIterator last, const T& val)
{
  first = std::lower_bound(first,last,val);
  return (first!=last && !(val<*first));
}

調用了lower_bound函數,然后去看下lower_bound。先不看代碼,這里我順便去看了<algorithm>中除了lower_bound還有upper_bound,很直白的意思,下界和上界,按照STL的設計應該也是左閉右開,用測試代碼描述如下:

#include <cstdio>
#include <algorithm>

int main()
{
    const size_t N = 5;
    int a[N] = { 1,2,2,2,2 };
    auto itL = std::lower_bound(a, a + N, 2);
    auto itR = std::upper_bound(a, a + N, 2);
    printf("itL == a[%d]\n", itL - a);  // itL == a[1]
    printf("itR == a[%d]\n", itR - a);  // itR == a[5]
    return 0;
}

注釋部分為該行輸出結果,lower_bound返回第1個與查找值相等的迭代器,upper_bound返回lower_bound開始第1個與查找值不等的迭代器。

這么說不太嚴密,因為數組中可能不存在與查找值相等的值。那么此時會返回什么?

測試代碼如下

    const size_t N = 5;
    double a[N] = { 1,2,2,2,2 };
    auto itL = std::lower_bound(a, a + N, 0);
    auto itR = std::upper_bound(a, a + N, 0);
    printf("itL == a[%d]\n", itL - a);  // itL == a[0]
    printf("itR == a[%d]\n", itR - a);  // itR == a[0]
    itL = std::lower_bound(a, a + N, 1.5);
    itR = std::upper_bound(a, a + N, 1.5);
    printf("itL == a[%d]\n", itL - a);  // itL == a[1]
    printf("itR == a[%d]\n", itR - a);  // itR == a[1]

嚴密又簡潔點講,在lower_bound左邊的值都<查找值,在upper_bound左邊的值都<=查找值。也就是我在開始提到的二分查找的第二種情況。

好了,那么回顧binary_search的代碼,它調用的是lower_bound,第一行返回第1個與查找值相等的迭代器並賦值給first,若不存在則first為第1個大於查找值的迭代器。

第二行返回(first!=last && !(val<*first)),顯然兩個bool表達式同時成立等價於數組中含有查找值。

lower_bound什么時候表示能找到值?——當然是返回的迭代器對應值等於查找值,因為查找失敗時,返回的是比該值大的迭代器,如果是我的話會直接寫一句return val==*first。——但是不對,因為如果整個區間[first, last)的值都比查找值小,那么返回的是last,一個無法訪問的迭代器,對*first的比較就會出錯。所以前面加了一句first != last。那么后面為什么用!(val<*first)(等價於*first<=val)來判斷呢?

問題等價於——什么時候查找成功並且*first<val?

關於這點,我仔細思考了一下,不知回答是否正確。

首先,這種情況是不存在的;

其次,這么寫是因為這是函數模板,可能進行比較的是類(而非基本數據類型),而這里只要求類重載了operator<用來比較,對於operator==甚至operator>都可有可無

 

OK,現在來看看lower_bound的實際實現

template <class ForwardIterator, class T>
  ForwardIterator lower_bound (ForwardIterator first, ForwardIterator last, const T& val)
{
  ForwardIterator it;
  iterator_traits<ForwardIterator>::difference_type count, step;
  count = distance(first,last);
  while (count>0)
  {
    it = first; step=count/2; advance (it,step);
    if (*it<val) {                 // or: if (comp(*it,val)), for version (2)
      first=++it;
      count-=step+1;
    }
    else count=step;
  }
  return first;
}

 翻譯成我之前的函數簽名即如下代碼(主要用於表達意思,沒考慮細枝末葉的優化)

template <typename T>
T* binary_search(T* a, size_t n, const T& x)
{
    T* first = a;         // 搜索區間起始位置(左閉)
    T* last = a + n;      // 搜索區間結束位置(右開)
    ptrdiff_t count = n;  // 搜索區間元素數量
    while (count > 0)
    {
        ptrdiff_t step = count / 2;
        T* mid = first + step;  // 二分點
        if (*mid < x) {  // 繼續查找[mid+1, last)
            first = mid + 1;
            count -= step + 1;
        }
        else
            count = step;   // 繼續查找[first, mid)
    }
    return first;
}

這里的關鍵點是while循環里的變成了區間數量count,而count不能簡單地用last-first來代替,即使它的初始值為last-first!

依舊考慮特殊情況:對數組a[2],first=0,last=2,count=2-0=2,進入while循環

count=2>0,計算step=2/2=1;mid=first+1,*mid即a[1];

若a[1]<x則需要查找[2,2),此時first變為mid+1=first+2,count變為2-(1+1)=0,last-first等於count;

否則,需要查找[0,1),此時count變為1,但是last-first=2不等於count!

說白了是把用first和last表示區間[first, last)改成用first和count表示區間[first, first+count)

這么一想,用first和last一樣能寫出這樣的代碼啊!於是我嘗試着改了下

template <typename T>
T* binary_search(T* a, size_t n, const T& x)
{
    T* first = a;         // 搜索區間起始位置(左閉)
    T* last = a + n;      // 搜索區間結束位置(右開)
    while (last - first > 0)
    {
        ptrdiff_t step = (last - first) / 2;
        if (first[step] < x)   // first + step為二分點mid
            first += (step + 1);  // 繼續查找[mid+1, last)
        else
            last = first + step;  // 繼續查找[first, mid)
    }
    return first;
}

這樣的代碼看起來思路更加自然,而且沒有什么漏洞(不考慮對類型T的要求的話)

再反思之前的做法,用起始位置和mid比較是不合適的,迭代的終止條件應該是搜索子區間元素數量大於0

我的實現思路從第一步就錯了!也沒必要去計算左區間大還是右區間大!因為終止條件是“當前區間”不為空!而不需要比較二分點和下界或者上界。

最后給出最終的二分搜索代碼(使用first和last表示左閉右開區間)

/**
* 功能: 在升序排好的數組a[n]用二分法查找元素x的位置
* 參數:
*   @a 數組首地址
*   @n 數組元素數量
*   @x 要查找的元素
* 返回: 若a[n]中存在元素x,返回任意一個與x相等的數組元素地址; 否則返回nullptr.
*/
template <typename T>
T* binary_search(T* a, size_t n, const T& x)
{
    size_t first = 0;
    size_t last = n;
    while (last - first > 0)
    {
        size_t mid = first + (last - first) / 2;
        if (a[mid] == x)
            return a + mid;
        else if (a[mid] < x)
            first = mid + 1;
        else
            last = mid;
    }
    return nullptr;
}

用閉區間low和high就稍微復雜點,因為high可能為-1,如果不把high的類型改為int,while里面需要額外判斷high!=-1,即while (high - low >= 0 && high != -1)

測試代碼和測試結果如下

#include "binary_search.h"
#include <cstdio>

int main()
{
    const size_t N = 16;
    int a[N] = { 1,2,3,4,5,6,7,8 };
    for (size_t i = 0; i < N; i++)
        a[i] = i + 1;
    for (int x = 0; x <= 30; x++)
    {
        auto p = binary_search(a, N, x);
        if (p != nullptr)
            printf("a[%d] = %d\n", p - a, *p);
        else
            printf("%d not found!\n", x);
    }
    return 0;
}

這篇博客雖然很啰嗦,而且基本功稍微扎實的都能看出我講了一堆廢話,但是主要目的還是記錄我從錯誤的思路轉向正確的過程,順便溫故了STL關於二分查找的函數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM