快速排序中的分割算法的解析與應用


一,分割(partition)算法介紹

所謂分割算法,先選定一個樞軸元素,然后 將數組中的元素分成兩部分:比樞軸元素小的部分都位於樞軸元素左邊;比樞軸元素大的部分都位於樞軸元素右邊

此時,樞軸元素在數組中的位置就被“永久地確定”下來了---將整個數組排序,該樞軸元素的位置不會變化。

另外,樞軸元素的選取對分割算法至關重要。一般而言,終極追求的是:將數組平分。因此,盡可能地讓樞軸元素的選取隨機化和靠近中位數。

這里采用“三數取中”法選取樞軸元素。

關於快速排序排序算法,可參考:排序算法總結之快速排序

 

二,分割算法的實現

 1 //分割數組,將數組分成兩部分. 一部分比pivot(樞軸元素)大,另一部分比pivot小
 2     private static int parition(int[] arr, int left, int right){
 3         
 4         int pivot = media3(arr, left, right);
 5         int i = left;
 6         int j = right - 1;//注意 ,在 media3()中 arr[right-1]就是 pivot
 7         
 8         for(;;)
 9         {
10             while(arr[++i] < pivot){}
11             while(arr[--j] > pivot){}
12             if(i < j)
13                 swap(arr, i, j);
14             else
15                 break;
16         }
17         
18         swap(arr, i, right-1);//restore pivot, 將樞軸元素放置到合適位置:arr左邊元素都比pivot小,右邊都比pivot大
19         return i;// 返回 pivot的 索引
20     }

①第4行,樞軸元素是通過“三數取中”法選擇的。在“三數取中”時,還做了一些優化:將 樞軸元素 放到 數組末尾的倒數第二個位置處。具體參考 media3()
需要注意的是:當輸入的數組中長度為1 或者 2 時, partition會出現向下越界(但對快排而言,當數組長度很小的,其實可以不用 partition,而是直接用插入排序)。因此,可加入以下的修改。

 1 //分割數組,將數組分成兩部分. 一部分比pivot(樞軸元素)大,另一部分比pivot小
 2     private static int parition(int[] arr, int left, int right){
 3         
 4         int pivot = media3(arr, left, right);
 5         int i = left;
 6         int j = right - 1;//注意 ,在 media3()中 arr[right-1]就是 pivot
 7         
 8         //應對特殊情況下的數組,比如數組長度 小於3
 9         if(i >= j)
10             return i;
11         
12         for(;;)
13         {
14             while(arr[++i] < pivot){}
15             while(arr[--j] > pivot){}
16             if(i < j)
17                 swap(arr, i, j);
18             else
19                 break;
20         }
21         
22         swap(arr, i, right-1);//restore pivot 將樞軸元素放置到合適位置:arr左邊元素都比pivot小,右邊都比pivot大
23         return i;// 返回 pivot的 索引
24     }

 

再來看看,三數取中算法,這里也有個特殊情況:當數組中元素個數都沒有3個時....怎么辦?

 1     //三數取中,用在快排中隨機選擇樞軸元素時
 2     private static int media3(int[] arr, int left, int right){
 3         if(arr.length == 1)
 4             return arr[0];
 5         
 6         if(left == right)
 7             return arr[left];
 8         
 9         int center = (left + right) / 2;
10         
11         //找出三個數中的最小值放到 arr[left]
12         if(arr[center] < arr[left])
13             swap(arr, left, center);
14         if(arr[right] < arr[left])
15             swap(arr, left, right);
16         
17         //將 中間那個數放到 arr[media]
18         if(arr[center] > arr[right])
19             swap(arr, center, right);
20         
21         swap(arr, center, right-1);//盡量將大的元素放到右邊--將privot放到右邊, 可簡化 分割操作(partition).
22         return arr[right-1];//返回中間大小的那個數
23     }

其實,這里的“三數取中”的實現,與參考資料中提到的三數取中實現有一點不同。這是正常的,畢竟實現細節不同。如果有錯誤,需要自行調試。

這里提下第3-7行的兩個if語句:當需要 “取中”的目標數組長度為1時,或者說 對數組中某些范圍內[left, right]的元素進行“取中”時,若left=right,則根本就沒有3個數,違背了“三數取中”的本意(隨機地選取樞軸元素),故直接 return。

當數組中元素只有一個時,第18行會越界。為了防止這種情況,在第3-4行就先對數組長度進行判斷。當數組中只有兩個元素,其實就相當於 center=left,因此,程序也沒問題。

 

三,分割算法的應用---O(N)時間復雜度找出無序數組中第k大的元素

給定一個數組,數組中某個元素出現的次數超過了數組大小的一半,找出這個元素。

比如輸入:[2,5,4,4,5,5,5,6,5] ,輸出 5

這個問題,其實可以轉化成求解中位數問題。因為,當數組有序時,出現次數超過一半的那個元素一定位於數組的中間。

所謂中位數,就是 假設 數組是有序的情況下,中間那個元素。即 arr[arr.length/2]

而要求解中位數,當然可以先對數組進行排序,但排序的時間復雜度為O(NlogN),那有沒有更快的算法?

當然是有的。就是借助partition分割算法 來 實現。

 1 //找出 arr 中 第  n/2  大的那個元素
 2     public static int media_number(int[] arr){
 3         int left = 0;
 4         int right = arr.length - 1;
 5         int center = (left + right) / 2;
 6         
 7         int pivot_index = parition(arr, left, right);//樞軸元素的數組下標
 8         
 9         while(pivot_index != center)
10         {
11             if(pivot_index > center){
12                 right = pivot_index - 1;
13                 pivot_index = parition(arr, left, right);
14             }
15             else{
16                 left = pivot_index + 1;
17                 pivot_index = parition(arr, left, right);
18             }
19         }
20         return arr[center];
21     }

上面算法不僅可以求解“找出超過一半的數字”,也可以求解任何一個數組的中位數。

這里遞歸表達式 T(N)=T(N/2)+O(N),O(N)表示將數組 分成兩部分所花的代價。

故時間復雜度為O(N)

 

四,參考資料

排序算法總結之快速排序

 整個完整代碼

public class Middle_Large {
    
    //找出 arr 中 第  n/2  大的那個元素
    public static int media_number(int[] arr){
        int left = 0;
        int right = arr.length - 1;
        int center = (left + right) / 2;
        
        int pivot_index = parition(arr, left, right);
        
        while(pivot_index != center)
        {
            if(pivot_index > center){
                right = pivot_index - 1;
                pivot_index = parition(arr, left, right);
            }
            else{
                left = pivot_index + 1;
                pivot_index = parition(arr, left, right);
            }
        }
        return arr[center];
    }
    
    //分割數組,將數組分成兩部分. 一部分比pivot(樞軸元素)大,另一部分比pivot小
    private static int parition(int[] arr, int left, int right){
        
        int pivot = media3(arr, left, right);
        int i = left;
        int j = right - 1;//注意 ,在 media3()中 arr[right-1]就是 pivot
        
        //應對特殊情況下的數組,比如數組長度 小於3
        if(i >= j)
            return i;
        
        for(;;)
        {
            while(arr[++i] < pivot){}
            while(arr[--j] > pivot){}
            if(i < j)
                swap(arr, i, j);
            else
                break;
        }
        
        swap(arr, i, right-1);//restore pivot 將樞軸元素放置到合適位置:arr左邊元素都比pivot小,右邊都比pivot大
        return i;// 返回 pivot的 索引
    }
    
    
    //三數取中,用在快排中隨機選擇樞軸元素時
    private static int media3(int[] arr, int left, int right){
        if(arr.length == 1)
            return arr[0];
        
     if(left == right)
return arr[left];
int center = (left + right) / 2; //找出三個數中的最小值放到 arr[left] if(arr[center] < arr[left]) swap(arr, left, center); if(arr[right] < arr[left]) swap(arr, left, right); //將 中間那個數放到 arr[media] if(arr[center] > arr[right]) swap(arr, center, right); swap(arr, center, right-1);//盡量將大的元素放到右邊--將privot放到右邊, 可簡化 分割操作(partition). return arr[right-1];//返回中間大小的那個數 } private static void swap(int[] arr, int left, int right){ int tmp = arr[left]; arr[left] = arr[right]; arr[right] = tmp; } public static void main(String[] args) { int[] arr = {5,6,8,4,1,5,5,5,5}; int result = media_number(arr); System.out.println(result); } }

 

 

另外,再寫了一個尋找第K(K從1開始)大元素的程序:

public class FindKLargest {

    public static <T extends Comparable<? super T>> T findK(T[] arr, int k) {
        k = k - 1;

        if (arr == null || arr.length == 0) {
            throw new IllegalArgumentException("array is null");
        }

        if (k < 0) {
            throw new IllegalArgumentException("k must be > 0");
        }

        if (k > arr.length - 1) {
            k = arr.length - 1;
        }

        int low = 0;
        int high = arr.length - 1;
        int pivot_index = partition(arr, low, high);
        while (pivot_index != k) {
            if (pivot_index > k) {
                high = pivot_index - 1;
                pivot_index = partition(arr, low, high);
            }else {
                low = pivot_index + 1;
                pivot_index = partition(arr, low, high);
            }
        }
        return arr[pivot_index];
    }

    public static <T extends Comparable<? super T>> int partition(T[] arr, int low, int high) {
        T pivot = pick(arr, low, high);
        int i = low;
        int j = high;
        for (; ; ) {
            while (arr[i++].compareTo(pivot) == -1) {
            }
            while (arr[j--].compareTo(pivot) == 1) {
            }
            if (i < j) {
                swap(arr, i, j);
            } else {
                break;
            }
        }
        return i - 1;
    }


    private static <T extends Comparable<? super T>> T pick(T[] arr, int low, int high) {
        return arr[(low + high) / 2];
    }

    private static <T extends Comparable<? super T>> void swap(T[] arr, int i, int j) {
        T tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }


    public static void main(String[] args) {
        String[] strings = {"abc", "bcd", "def"};
        System.out.println(findK(strings, 1));
//        System.out.println(findK(strings, 0));


        String[] strings1 = {"abc"};
        System.out.println(findK(strings1, 2));

        Long[] longs = {1L, 2L, 3L, 4L, 5L};
        System.out.println(findK(longs, 5));
        System.out.println(findK(longs, 1));
        System.out.println(findK(longs, 2));
    }
}

 

原文:https://www.cnblogs.com/hapjin/p/5587014.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM