用Python實現最大堆


  本文的內容是如何通過二叉樹實現一個最大堆, 實現原理方面參考了Python的heap模塊. 此外, 在正式項目上, 我還是建議你使用python自帶的heap完成, 它只提供最小堆, 但是可以通過對所有元素取反或者重寫__lt__方法實現最大堆.

一. 堆的數據結構

1. 數據結構分析

  堆的本質就是一顆二叉樹, 這顆二叉樹必須具備以下兩個性質:

 1). 對於最大堆來說, 二叉樹根節點的值不小於任何子節點, 其所有子樹也符合這一特征, 最小堆則相反;

 2). 堆是一顆完全二叉樹, 除了底層外, 所有層都盡可能地填滿, 底層元素從左到右排列.

  上圖就是一個最大堆的二叉樹, 基於特性1我們可以得知, 這顆二叉樹從任意葉子節點到根節點的路徑一定是一個遞增序列, 最大值為根節點. 因此, 當我們需要最大值時, 取出根節點的值就行了. 當我們新添加了一個葉子節點之后, 為了維護二叉樹的有序性, 我們可以讓這個葉子節點向頂端移動, 如下圖所示:

->->

我們插入節點16后, 將這個節點的值與其父節點進行比較, 大於父節點則二者交換, 持續這個操作直到不大於父節點或沒有父節點為止, 這樣, 我們就在插入元素之后, 仍然保持了二叉樹的有序性. 彈出節點同理, 將底層最后一個葉子節點取出填入空缺, 然后根據值的大小讓這個節點往下移動就行.

  因此, 堆在保證內部有序性的前提下, 可以做到在O(k)的時間內插入和彈出元素, k為二叉樹的高度. 這也就是為什么堆的二叉樹必須是完全二叉樹: 在這種情況下k最小, 為log n. 因此, 堆的插入和彈出都只需要O(log n)的時間復雜度, 可以高效地獲取最大值/最小值.

 

2. 通過列表實現二叉樹

  由於堆是一顆完全二叉樹, 因此我們可以用一個列表來儲存這顆二叉樹的值:

  如上圖所示, 我們用列表從上到下, 從左到右記錄了二叉樹的所有節點. 二叉樹節點右邊的藍色數字是它在列表中的索引. 因此我們可以得知, 對於一個在列表中索引為n的節點, 它的父節點索引為(n-1)//2, 它的左右子節點索引為n*2+1和n*2+2, 如果索引值溢出, 說明沒有對應的父節點或子節點. 這樣, 我們就通過列表儲存了這顆完全二叉樹的信息.

  基於以上的分析, 我們先定義一個Heap類:

class Heap:

    def __init__(self, nums: [int] = None) -> None:
        self.cache = nums or []
        self._heapify()

    def __len__(self) -> int:
        return len(self.cache)

    def __bool__(self) -> bool:
        return len(self) > 0

    def __repr__(self) -> str:
        return f'heap({self.cache})'

    @property
    def largest(self) -> int:
        if not self.cache:
            raise Exception('Empty heap')
        return self.cache[0]

    def show(self) -> None:
        # 調用這個函數繪制一顆二叉樹出來,DEBUG用
        height = int(math.log2(len(self))) + 1
        for i in range(height):
            width = 2 ** (height - i) - 2
            print(' ' * width, end='')
            blank = ' ' * (width * 2 + 2)
            print(
                blank.join(['{: >2d}'.format(num) for num in self.cache[2 ** i - 1:min(2 ** (i + 1) - 1, len(self))]]))
            print()

    def _swap(self, i: int, j: int) -> None:
        # 這個方法交換二叉樹的兩個節點
        self.cache[i], self.cache[j] = self.cache[j], self.cache[i]

二. 插入元素

  這部分好像太簡單了, 我實在講不出來什么:

    def push(self, num: int) -> None:
        self.cache.append(num)
        self._siftup(self.size - 1)

    def _siftup(self, i: int) -> None:
        while i > 0:
            parent = (i - 1) >> 1
            if self.cache[i] <= self.cache[parent]:
                break
            self._swap(i, parent)
            i = parent

說白了, 當我們push一個元素時, 首先把這個元素放到列表的末端, 這相當於在完全二叉樹上新建了一個葉子節點. 然后, 調用siftup方法讓這個節點一直和父節點比較, 大於父節點就上浮, 直到它到達合適的位置. 這樣就維護了二叉樹的有序性.

三. 彈出元素

  彈出元素的原理和插入元素大同小異: 我們將根節點的元素彈出后, 取出最后一個葉子節點作為根節點(避免破壞完全二叉樹的結構), 然后讓這個節點與子節點比較, 下沉到合適的位置就行. 有兩點需要注意一下: 首先, 最大元素處在列表的頭部, 彈出的時間復雜度是O(n), 因此我們可以把頭部元素和尾部元素交換后, 刪除尾部元素. 然后, 大部分節點都有兩個子節點, 我們應該讓更大的那個節點上浮, 這樣才能保證二叉樹的有序性.

  基於以上兩點, 彈出元素的代碼如下:

    def pop(self) -> int:
        largest = self.largest
        self._swap(0, len(self) - 1)
        self.cache.pop()
        self._siftdown(0)
        return largest

    def _siftdown(self, i: int) -> None:
        while i * 2 + 1 < len(self):
            smaller = i
            if self.cache[i * 2 + 1] > self.cache[smaller]:
                smaller = i * 2 + 1
            if i * 2 + 2 < len(self) and self.cache[i * 2 + 2] > self.cache[smaller]:
                smaller = i * 2 + 2
            if smaller == i:
                return
            self._swap(i, smaller)
            i = smaller

四. 列表的堆化

  我們在創建Heap對象時傳入了一個列表作為堆的原始數據, 但是, 這個列表並不一定是顆有序的二叉樹, 因此我們需要將其堆化.

  最容易想到的方式是, 首先創建一個空堆, 然后將列表的所有元素依次推入堆中, 通過_siftup方法保持有序:

如上圖所示, 如果我們通過_siftup來堆化所有元素, 則時間復雜度為O(n/2*log n+n/4*log n/2+...+1*1)=O(nlog n), 這和排序的時間復雜度差不多, 因此不是很理想.

  另外一種方案是, 首先按照列表的原有順序構建二叉樹, 然后從二叉樹的倒數第二層開始, 依次通過_siftdown下沉, 這樣依次為k-1層, k-2層直到頂層排序:

  這種堆化方式的時間復雜度為O(n), 計算過程如下:

T(n)=O(n/4)+O(n/8*2)+(n/16*3)+O(log n)
2*T(n)=O(n/2)+O(n/4*2)+(n/8*3)+O(2*log n)
2*T(n)-T(n)=O(n/2)+O(n/4)+O(n/8)+...+O(log n)=O(n)

   因此, 我們的堆化方法可以這么寫:

    def _heapify(self) -> None:
        for i in reversed(range(len(self) // 2)):
            self._siftdown(i)

五. 總結

  簡單對我們創建的Heap類進行測試:

nums = list(range(14))
random.shuffle(nums)
heap = Heap(nums[:])
heap.show()
heap.push(100)
print('插入100')
heap.show()
heap.pop()
print('彈出堆頂元素')
heap.show()
for _ in range(100):
    num = random.randrange(100)
    nums.append(num)
    heap.push(num)
    assert max(nums) == heap.largest
    nums.remove(heap.pop())

print('所有測試通過!!!')

結果如下:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM