Java - PriorityQueue


JDK 10.0.2

前段時間在網上刷題,碰到一個求中位數的題,看到有網友使用PriorityQueue來實現,感覺其解題思想挺不錯的。加上我之前也沒使用過PriorityQueue,所以我也試着去讀該類源碼,並用同樣的思想解決了那個題目。現在來對該類做個總結,需要注意,文章內容以算法和數據結構為中心,不考慮其他細節內容。如果小伙伴想看那個題目,可以直接跳轉到(小測試)。

目錄


 一. 數據結構

我只列出了講解需要的重要屬性,不考慮其他細節。PriorityQueue(優先隊列)內部是以來實現的。為了描述方便,接下來的內容我將用pq[ ]代替queue[ ]

PriorityQueue<E> {
    /* 平衡二叉堆 用於存儲元素
     * n : 0 -> size-1
     * pq[n].left = pq[2*n+1]
     * pq[n].right = pq[2*(n+1)]
     */
    Object[] queue; 
    int size; // pq中元素個數
    Comparator<? super E> comparator; // 自定義比較器
}

回到目錄

二. 初始化(堆化)

如果使用已有集合來構造PriorityQueue,就會用到heapify()來對pq[ ]進行初始化(即:二叉堆化),使其滿足堆的性質。而heapify()又通過調用siftDownComparable(k, e)來完成堆化。源碼如下:

 1 @SuppressWarnings("unchecked")
 2 private void heapify() {
 3     final Object[] es = queue;
 4     int i = (size >>> 1) - 1;
 5     if (comparator == null)
 6         for (; i >= 0; i--)
 7             siftDownComparable(i, (E) es[i]);
 8     else
 9         for (; i >= 0; i--)
10             siftDownUsingComparator(i, (E) es[i]);
11 }
12 
13 @SuppressWarnings("unchecked")
14 private void siftDownComparable(int k, E x) {
15     Comparable<? super E> key = (Comparable<? super E>)x;
16     int half = size >>> 1;        // loop while a non-leaf
17     while (k < half) {
18         int child = (k << 1) + 1; // assume left child is least
19         Object c = queue[child];
20         int right = child + 1;
21         if (right < size &&
22             ((Comparable<? super E>) c).compareTo((E) queue[right]) > 0)
23             c = queue[child = right];
24         if (key.compareTo((E) c) <= 0)
25             break;
26         queue[k] = c;
27         k = child;
28     }
29     queue[k] = key;
30 }
View Code

如果有自定義比較器的話,調用:siftDownUsingComparator(k, e),否則調用:siftDownComparable(k, e)。這兩個方法只是在比較兩個元素大小時的表現形式不同,其他內容相同,所以我們只需要看其中一種情況就行。為了描述方便,下面的例子中,我使用Integer作為pq[ ]存儲元素類型,所以調用的是siftDownComparable(k, e)(size >>> 1 表示 size 無符號右移1位,等價於size / 2)

我不會去細摳源碼,一行一行地為大家講解,而是盡量使用簡單的例子來展示,我覺得通過例子以及后期大家自己閱讀源碼,會更容易理解算法內容。

現在我們來看看,使用集合{2, 9, 8, 4, 7, 1, 3, 6, 5}來構造PriorityQueue的過程。算法時間復雜度為O(n),n = size。(時間復雜度證明:《算法導論》(第3版)第6章6.3建堆)

  • 首先,從下到上,從右到左,找到第一個父結點 i,滿足規律:i = (size >>> 1) - 1,這里size = 9,i = 3;
  • 比較pq[3, 7, 8]中的元素,將最小的元素pq[x]與堆頂元素pq[3]互換,由於pq[x] = pq[3],所以無互換;
  • 移動到下一個父結點 i = 2,同理,比較pq[2, 5, 6]中的元素,將最小的元素pq[5]與pq[2]互換,后面的操作同理;
  • 需要注意,當pq[1](9)和pq[3](4)互換后(如圖2.d),pq[3, 7, 8]違背了最小堆的性質,所以需要進一步調整(向下調整),當調整到葉結點時(i >= size/2)結束

回到目錄

三. 添加元素

添加元素:add(e),offer(e),由於添加元素可能破壞堆的性質,所以需要調用siftUp(i, e)向上調整來維護堆性質。同樣,siftUp(i, e)根據有無自定義比較器來決定調用siftUpUsingComparator(k, e)還是siftUpComparable(k, e)。在我舉的例子中,使用的是siftUpComparable(k, e)。下面是添加元素的相關源碼:

 1 public boolean offer(E e) {
 2     if (e == null)
 3         throw new NullPointerException();
 4     modCount++;
 5     int i = size;
 6     if (i >= queue.length)
 7         grow(i + 1);
 8     siftUp(i, e);
 9     size = i + 1;
10     return true;
11 }
12 
13 @SuppressWarnings("unchecked")
14 private void siftUpComparable(int k, E x) {
15     Comparable<? super E> key = (Comparable<? super E>) x;
16     while (k > 0) {
17         int parent = (k - 1) >>> 1;
18         Object e = queue[parent];
19         if (key.compareTo((E) e) >= 0)
20             break;
21         queue[k] = e;
22         k = parent;
23     }
24     queue[k] = key;
25 }
View Code

源碼中 grow(i + 1) 是當pq[ ]容量不夠時的增長策略,目前可以不用考慮。現在來看往最小堆 pq = {3, 5, 6, 7, 8, 9} 中添加元素 1的過程。算法時間復雜度為O(lgn),n = size。

  • 首先,把要添加的元素 1 放到pq[size],然后調用siftUp(k, e)來維護堆,調整結束后 size++;
  • 向上調整(k, e)時,先找到結點pq[k]的父結點,滿足規律 parent = (k - 1) >>> 1,例子中,k = 6, parent = 2;
  • 比較pq[k]與pq[parent],將較小者放到高處,較大者移到低處,例子中,交換pq[6](1)與pq[2](6)的位置;
  • 此次交換結束后,令 k = parent,繼續以同樣的方法操作,直到 k <= 0 時(到達根結點)結束;

回到目錄

四. 索引

indexOf(o)是個私有方法,但好多公開方法中都調用了它,比如:remove(o),contains(o)等,所以在這里也簡單提一下。該算法並不復雜。時間復雜度為O(n),n = size。

1 private int indexOf(Object o) {
2     if (o != null) {
3         for (int i = 0; i < size; i++)
4             if (o.equals(queue[i]))
5                 return i;
6     }
7     return -1;
8 }
View Code

indexOf(o)中比較兩個元素是否相等,使用的是equals(),而接下來要提的removeEq(o)中直接使用了 == 來判斷,請讀者注意區別。

回到目錄

五. 刪除元素

remove(o)、removeEq(o),二者只是在判斷兩個元素是否相等時使用的方法不同(前者使用equals(),后者使用==),其他內容相同,它們都調用了removeAt(i)來執行刪除操作。刪除元素后很可能會破壞堆的性質,所以同樣需要進行維護。刪除元素的維護要比添加元素的維護稍微復雜一點,因為可能同時涉及了:向上調整siftUp和向下調整siftDown。源碼如下:

 1 public boolean remove(Object o) {
 2     int i = indexOf(o);
 3     if (i == -1)
 4         return false;
 5     else {
 6         removeAt(i);
 7         return true;
 8     }
 9 }
10 
11 boolean removeEq(Object o) {
12     for (int i = 0; i < size; i++) {
13         if (o == queue[i]) {
14             removeAt(i);
15             return true;
16         }
17     }
18     return false;
19 }
20 
21 @SuppressWarnings("unchecked")
22 E removeAt(int i) {
23     // assert i >= 0 && i < size;
24     modCount++;
25     int s = --size;
26     if (s == i) // removed last element
27         queue[i] = null;
28     else {
29         E moved = (E) queue[s];
30         queue[s] = null;
31         siftDown(i, moved);
32         if (queue[i] == moved) {
33             siftUp(i, moved);
34             if (queue[i] != moved)
35                 return moved;
36         }
37     }
38     return null;
39 }
View Code

我們還是通過例子來學習吧,通過對 pq = {0, 1, 7, 2, 3, 8, 9, 4, 5, 6} 進行一系列刪除操作,來理解算法的運作過程。算法時間復雜度O(lgn),n = size。

  • 第1步,remove(6),indexOf(6) = 9,removeAt(9)(用r(9)表示,后面同理),由於i = 9為隊列末端,刪除后不會破壞堆性質,所以可以直接刪除;
  • 第2步,remove(1),即r(1),根據圖(5.b)可以看出,算法是拿隊列尾部pq[8]去替換pq[1],替換后破壞了最小堆的性質,需要向下調整進行維護;
  • 第3步,remove(8),即r(5),使用隊列尾部元素pq[7]替換pq[5],替換后破壞了最小堆的性質,需要向上調整進行維護;

回到目錄

六. 取堆頂

peek()可以在O(1)的時間復雜度下取到堆頂元素pq[0],看源碼一目了然:

1 @SuppressWarnings("unchecked")
2 public E peek() {
3     return (size == 0) ? null : (E) queue[0];
4 }
View Code

回到目錄

七. 刪除堆頂

刪除堆頂使用poll()方法,其算法思想等價於removeAt(0)(時間復雜度O(lgn)),稍微有點區別的是,其只涉及到向下調整,不涉及向上調整。不清楚的朋友可以參看(五. 刪除元素),下面是源碼:

 1 @SuppressWarnings("unchecked")
 2 public E poll() {
 3     if (size == 0)
 4         return null;
 5     int s = --size;
 6     modCount++;
 7     E result = (E) queue[0];
 8     E x = (E) queue[s];
 9     queue[s] = null;
10     if (s != 0)
11         siftDown(0, x);
12     return result;
13 }
View Code

回到目錄

八. 清除隊列

清除隊列clear(),就是依次把pq[i]置為null,然后size置0,但是pq.length沒有改變。時間復雜度為O(n),n = size。源碼如下:

1 public void clear() {
2     modCount++;
3     for (int i = 0; i < size; i++)
4         queue[i] = null;
5     size = 0;
6 }
View Code

回到目錄

九. 遍歷

可以使用迭代器(Iterator)來遍歷pq[ ]本身,或者調用toArray()、toArray(T[] a)方法來生成一個pq[ ]的副本進行遍歷。遍歷本身的時間復雜度為O(n),n = size。

使用迭代器遍歷 pq = {0, 1, 7, 2, 3, 8, 9, 4, 5, 6},方法如下:

 1 public static void traverse1(PriorityQueue<Integer> x) {
 2     Iterator<Integer> it = x.iterator();
 3     while (it.hasNext()) {
 4         System.out.print(it.next() + " ");
 5     }
 6     System.out.println();
 7 }
 8 // 或者更簡單的,結合java語法糖,可以寫成如下形式
 9 public static void traverse2(PriorityQueue<Integer> x) {
10     for (int a : x) {
11         System.out.print(a + " ");
12     }
13     System.out.println();
14 }
15 /* 輸出
16 0 1 7 2 3 8 9 4 5 6 
17 */
View Code

通過拷貝pq[ ]副本來遍歷,方法如下:

 1 public static void traverse3(PriorityQueue<Integer> x) {
 2     Object[] ins = x.toArray();
 3     for (Object a : ins) {
 4         System.out.print((Integer)a + " ");
 5     }
 6     System.out.println();
 7 }
 8 
 9 public static void traverse4(PriorityQueue<Integer> x) {
10     Integer[] ins = new Integer[100];
11     ins = x.toArray(ins);
12     for (int i = 0, len = x.size(); i < len; i++) {
13         System.out.print(ins[i] + " ");
14     }
15     System.out.println();
16 }
17 /* 輸出
18 0 1 7 2 3 8 9 4 5 6 
19 */
View Code

在使用toArray(T[] a)拷貝來進行遍歷時,需要注意(x表示PriorityQueue對象):

  • 如果ins[ ]的容量大於x.size(),請使用for (int i = 0; i < x.size(); i++) 來遍歷,否則可能會獲取到多余的數據;或者你使用for (int a : ins)來遍歷時,可能導致NullPointerException異常;
  • 請使用 ins = x.toArray(ins) 的寫法來確保正確獲取到pq[ ]副本。當ins[ ]容量大於x.size()時,寫為 x.toArray(ins) 能正確獲取到副本,但當ins[ ]容量小於x.size()時,該寫法就無法正確獲取副本。因為此情況下toArray(T[] a)內部會重新生成一個大小為x.size()的Integer數組進行拷貝,然后return該數組;

toArray(T[] a)源碼如下:

 1 @SuppressWarnings("unchecked")
 2 public <T> T[] toArray(T[] a) {
 3     final int size = this.size;
 4     if (a.length < size)
 5         // Make a new array of a's runtime type, but my contents:
 6         return (T[]) Arrays.copyOf(queue, size, a.getClass());
 7     System.arraycopy(queue, 0, a, 0, size);
 8     if (a.length > size)
 9         a[size] = null;
10     return a;
11 }
View Code

回到目錄

十. 小測試

下面來說說文章開頭我提到的那個題目吧,如下(點擊這里在線做題)(請使用PriorityQueue來完成):

/* 數據流中的中位數
題目描述
如何得到一個數據流中的中位數?如果從數據流中讀出奇數個數值,那么中位數就是所有數值排序之后位於中間的數值。
如果從數據流中讀出偶數個數值,那么中位數就是所有數值排序之后中間兩個數的平均值。我們使用Insert()方法讀取數據流,
使用GetMedian()方法獲取當前讀取數據的中位數。
*/

public class Solution {
    public void Insert(Integer num) {}
    public Double GetMedian() {}
}

我寫的參考代碼(帶解析),如下:

 1 /*
 2 關鍵點:
 3  大根堆maxq       小根堆minq
 4 ----------      -------------
 5           \    /
 6   <= A     A  B   >= B
 7           /    \
 8 ----------      -------------
 9  
10 每次insert(num)前要確保 :
11     1) maxq.size == q.size // 偶數個時,二者元素個數相等
12 或  2) minq.size == maxq.size + 1 // 奇數個時把多余的1個放到小根堆minq
13 這樣一來,獲取中位數時:
14 奇數個:minq.top;
15 偶數個:(minq.top + maxq.top) / 2
16  
17 每次isnert(num)后,可能會打破上面的條件,出現下面的情況:
18     1) maxq.size == q.size + 1 // 打破條件(1) => 這時需要把maxq.top放到minq中
19 或  2) minq.size == maxq.size + 2 // 打破條件(2) => 這時需要把minq.top放到maxq中
20 */
21 
22 import java.util.Comparator;
23 import java.util.PriorityQueue;
24  
25 public class JZOffer_63_Solution_02 {
26     PriorityQueue<Integer> minq = new PriorityQueue<Integer>();
27     PriorityQueue<Integer> maxq = new PriorityQueue<Integer>((o1, o2) -> o2.compareTo(o1));
28 
29     public void Insert(Integer num) {
30         if (minq.isEmpty() || num >= minq.peek()) minq.offer(num);
31         else maxq.offer(num);
32         if (minq.size() == maxq.size()+2) maxq.offer(minq.poll());
33         if (maxq.size() == minq.size()+1) minq.offer(maxq.poll());
34     }
35 
36     public Double GetMedian() {
37         return minq.size() == maxq.size() ? (double)(minq.peek()+maxq.peek())/2.0 : (double)minq.peek();
38     }
39 }
View Code

回到目錄

 轉載請說明出處,have a good time! :D


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM