[Java 8] (8) Lambda表達式對遞歸的優化(上) - 使用尾遞歸 .


遞歸優化

很多算法都依賴於遞歸,典型的比如分治法(Divide-and-Conquer)。但是普通的遞歸算法在處理規模較大的問題時,常常會出現StackOverflowError。處理這個問題,我們可以使用一種叫做尾調用(Tail-Call Optimization)的技術來對遞歸進行優化。同時,還可以通過暫存子問題的結果來避免對子問題的重復求解,這個優化方法叫做備忘錄(Memoization)。

本文首先對尾遞歸進行介紹,下一票文章中會對備忘錄模式進行介紹。

使用尾調用優化

當遞歸算法應用於大規模的問題時,容易出現StackOverflowError,這是因為需要求解的子問題過多,遞歸嵌套層次過深。這時,可以采用尾調用優化來避免這一問題。該技術之所以被稱為尾調用,是因為在一個遞歸方法中,最后一個語句才是遞歸調用。這一點和常規的遞歸方法不同,常規的遞歸通常發生在方法的中部,在遞歸結束返回了結果后,往往還會對該結果進行某種處理。

Java在編譯器級別並不支持尾遞歸技術。但是我們可以借助Lambda表達式來實現它。下面我們會通過在階乘算法中應用這一技術來實現遞歸的優化。以下代碼是沒有優化過的階乘遞歸算法:

public class Factorial { public static int factorialRec(final int number) { if(number == 1) return number; else return number * factorialRec(number - 1); } } 

以上的遞歸算法在處理小規模的輸入時,還能夠正常求解,但是輸入大規模的輸入后就很有可能拋出StackOverflowError:

try { System.out.println(factorialRec(20000)); } catch(StackOverflowError ex) { System.out.println(ex); } // java.lang.StackOverflowError 

出現這個問題的原因不在於遞歸本身,而在於在等待遞歸調用結束的同時,還需要保存了一個number變量。因為遞歸方法的最后一個操作是乘法操作,當求解一個子問題時(factorialRec(number - 1)),需要保存當前的number值。所以隨着問題規模的增加,子問題的數量也隨之增多,每個子問題對應着調用棧的一層,當調用棧的規模大於JVM設置的閾值時,就發生了StackOverflowError。

轉換成尾遞歸

轉換成尾遞歸的關鍵,就是要保證對自身的遞歸調用是最后一個操作。不能像上面的遞歸方法那樣:最后一個操作是乘法操作。而為了避免這一點,我們可以先進行乘法操作,將結果作為一個參數傳入到遞歸方法中。但是僅僅這樣仍然是不夠的,因為每次發生遞歸調用時還是會在調用棧中創建一個棧幀(Stack Frame)。隨着遞歸調用深度的增加,棧幀的數量也隨之增加,最終導致StackOverflowError。可以通過將遞歸調用延遲化來避免棧幀的創建,以下代碼是一個原型實現:

public static TailCall<Integer> factorialTailRec( final int factorial, final int number) { if (number == 1) return TailCalls.done(factorial); else return TailCalls.call(() -> factorialTailRec(factorial * number, number - 1)); } 

需要接受的參數factorial是初始值,而number是需要計算階乘的值。 我們可以發現,遞歸調用體現在了call方法接受的Lambda表達式中。以上代碼中的TailCall接口和TailCalls工具類目前還沒有實現。

創建TailCall函數接口

TailCall的目標是為了替代傳統遞歸中的棧幀,通過Lambda表達式來表示多個連續的遞歸調用。所以我們需要通過當前的遞歸操作得到下一個遞歸操作,這一點有些類似UnaryOperator函數接口的apply方法。同時,我們還需要方法來完成這幾個任務:

  1. 判斷遞歸是否結束了
  2. 得到最后的結果
  3. 觸發遞歸

因此,我們可以這樣設計TailCall函數接口:

@FunctionalInterface
public interface TailCall<T> { TailCall<T> apply(); default boolean isComplete() { return false; } default T result() { throw new Error("not implemented"); } default T invoke() { return Stream.iterate(this, TailCall::apply) .filter(TailCall::isComplete) .findFirst() .get() .result(); } } 

isComplete,result和invoke方法分別完成了上述提到的3個任務。只不過具體的isComplete和result還需要根據遞歸操作的性質進行覆蓋,比如對於遞歸的中間步驟,isComplete方法可以返回false,然而對於遞歸的最后一個步驟則需要返回true。對於result方法,遞歸的中間步驟可以拋出異常,而遞歸的最終步驟則需要給出結果。

invoke方法則是最重要的一個方法,它會將所有的遞歸操作通過apply方法串聯起來,通過沒有棧幀的尾調用得到最后的結果。串聯的方式利用了Stream類型提供的iterate方法,它本質上是一個無窮列表,這也從某種程度上符合了遞歸調用的特點,因為遞歸調用發生的數量雖然是有限的,但是這個數量也可以是未知的。而給這個無窮列表畫上終止符的操作就是filter和findFirst方法。因為在所有的遞歸調用中,只有最后一個遞歸調用會在isComplete中返回true,當它被調用時,也就意味着整個遞歸調用鏈的結束。最后,通過findFirst來返回這個值。

如果不熟悉Stream的iterate方法,可以參考上一篇文章,在其中對該方法的使用進行了介紹。

創建TailCalls工具類

在原型設計中,會調用TailCalls工具類的call和done方法:

  • call方法用來得到當前遞歸的下一個遞歸
  • done方法用來結束一系列的遞歸操作,得到最終的結果
public class TailCalls { public static <T> TailCall<T> call(final TailCall<T> nextCall) { return nextCall; } public static <T> TailCall<T> done(final T value) { return new TailCall<T>() { @Override public boolean isComplete() { return true; } @Override public T result() { return value; } @Override public TailCall<T> apply() { throw new Error("end of recursion"); } }; } } 

在done方法中,我們返回了一個特殊的TailCall實例,用來代表最終的結果。注意到它的apply方法被實現成被調用拋出異常,因為對於最終的遞歸結果,是沒有后續的遞歸操作的。

以上的TailCall和TailCalls雖然是為了解決階乘這一簡單的遞歸算法而設計的,但是它們無疑在任何需要尾遞歸的算法中都能夠派上用場。

使用尾遞歸函數

使用它們來解決階乘問題的代碼很簡單:

System.out.println(factorialTailRec(1, 5).invoke()); // 120 System.out.println(factorialTailRec(1, 20000).invoke()); // 0 

第一個參數代表的是初始值,第二個參數代表的是需要計算階乘的值。

但是在計算20000的階乘時得到了錯誤的結果,這是因為整型數據無法容納這么大的結果,發生了溢出。對於這種情況,可以使用BigInteger來代替Integer類型。

實際上factorialTailRec的第一個參數是沒有必要的,在一般情況下初始值都應該是1。所以我們可以做出相應地簡化:

public static int factorial(final int number) { return factorialTailRec(1, number).invoke(); } // 調用方式 System.out.println(factorial(5)); System.out.println(factorial(20000)); 

使用BigInteger代替Integer

主要就是需要定義decrement和multiple方法來幫助完成大整型數據的階乘操作:

public class BigFactorial { public static BigInteger decrement(final BigInteger number) { return number.subtract(BigInteger.ONE); } public static BigInteger multiply( final BigInteger first, final BigInteger second) { return first.multiply(second); } final static BigInteger ONE = BigInteger.ONE; final static BigInteger FIVE = new BigInteger("5"); final static BigInteger TWENTYK = new BigInteger("20000"); //... private static TailCall<BigInteger> factorialTailRec( final BigInteger factorial, final BigInteger number) { if(number.equals(BigInteger.ONE)) return done(factorial); else return call(() -> factorialTailRec(multiply(factorial, number), decrement(number))); } public static BigInteger factorial(final BigInteger number) { return factorialTailRec(BigInteger.ONE, number).invoke(); } }


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM