模型穩定性


算法工程師的責任不僅是提出算法,而是提出更穩定的算法

1. 計算的穩定性(Computational Stability)

計算穩定性特指模型運算性能的魯棒性(Robustness),我猜計算機背景的朋友肯定不會對此感到陌生。舉個簡單例子,如果我們讓整數型(int)的變量來儲存的一個浮點變量(float),那么我們會損失精度。在機器學習中,我們往往涉及了大量的計算,受限於計算機的運算精度,很多時候我們必須進行湊整(Rounding),將無理數近似到浮點數。這個過程中不可避免的造成了大量的微小誤差,隨着湊整誤差累計積少成多,最終會導致系統報錯或者模型失敗。我們一起來看看機器學習中幾種常見的的計算穩定性風險。

1.1. 下溢(Underflow)和上溢(Overflow)

顧名思義,溢出是代表內容超過了容器的極限。在機器學習當中,因為我們大量的使用概率(Probability),而概率的區間往往在0至1之間,這就導致了下溢發生的可能性大大提高。

舉個簡單的例子,我們常常需要將多個概率相乘,假設每個概率 P_i = 0.01 :

P = P_i^{10}=0.000000000000000000001

從此可以看出,僅僅需要是個1%的概率相乘就可以得到一個極小的結果。而機器學習中往往是成百上千個數字相乘,類似的情況導致計算機無法分辨0和和一個極小數之間的區別。在這種情況下,下溢可能導致模型直接失敗。

相似的,上溢也是很容易發生的狀況。試想我們需要將多個較大的數相乘,很輕易的就可以超過計算機的上限。64位計算機的數值上限並沒有大家想象中那么大:

L_{upper} = 2^{63}-1=9,223,372,036,854,775,808

因此在實際模型中,我們會避免將多個概率相乘,而轉為求其對數(Log),舉例:

ln(\prod_{i=1}^nP(x_i)) = ln(P(x_1))+ln(P(x_2))+...+ln(P(x_n))=ln(\sum_{i=1}^nP(x_i))

這樣我們就成功的將多項連乘轉化為了多項加法,避免了可能發生的溢出。而對數還有更多優美的數學的性質,例如其單調遞增性,易轉化為概率模型,凸優化性等。

1.1. 下溢(Underflow)和上溢(Overflow)

顧名思義,溢出是代表內容超過了容器的極限。在機器學習當中,因為我們大量的使用概率(Probability),而概率的區間往往在0至1之間,這就導致了下溢發生的可能性大大提高。

舉個簡單的例子,我們常常需要將多個概率相乘,假設每個概率 P_i = 0.01 :

P = P_i^{10}=0.000000000000000000001

從此可以看出,僅僅需要是個1%的概率相乘就可以得到一個極小的結果。而機器學習中往往是成百上千個數字相乘,類似的情況導致計算機無法分辨0和和一個極小數之間的區別。在這種情況下,下溢可能導致模型直接失敗。

相似的,上溢也是很容易發生的狀況。試想我們需要將多個較大的數相乘,很輕易的就可以超過計算機的上限。64位計算機的數值上限並沒有大家想象中那么大:

L_{upper} = 2^{63}-1=9,223,372,036,854,775,808

因此在實際模型中,我們會避免將多個概率相乘,而轉為求其對數(Log),舉例:

ln(\prod_{i=1}^nP(x_i)) = ln(P(x_1))+ln(P(x_2))+...+ln(P(x_n))=ln(\sum_{i=1}^nP(x_i))

這樣我們就成功的將多項連乘轉化為了多項加法,避免了可能發生的溢出。而對數還有更多優美的數學的性質,例如其單調遞增性,易轉化為概率模型,凸優化性等。

1.2. 平滑(Smoothing)與0

和下溢和上溢類似,我們常常會發現機器學習中遇到“連乘式”中某個元素為0,導致運算失去意義。以朴素貝葉斯(Naive Bayes)為例:

P(\bm{x}|y=c) = P(y=c)\prod_{i=1}^{d}P(x_i|c)

我們判別一個樣本點屬於某個分類 c 的概率為其各項特征 x_i 屬於分類 c 的概率 P(x_i|c) 之乘積,即上式。但假設只要有任何一項 P(x_i|c)=0或者 P(y=c)=0 ,那么這個乘式的乘積就會為0。然而出現0往往並不是真的因為其概率為0,而僅僅是我們的訓練數據沒有出現過。

從某種意義上來說,這也屬於一種計算上的不穩定。常見的做法是用拉普拉斯平滑(Laplace Smoothing)來修正這種計算不穩。簡單的說就是人為的給每種可能性加一個例子,使其概率不再為0。

於是某個特征取特定值在分類下的概率就會被修正為:

Lap(P(x_i|c)) = \frac{|D_{c,x_i}|+1}{|D_c|+N_i}

在這種平滑處理后,我們所有乘子的取值都不會為0。相似的做法在自然語言處理(NLP)中也常常會用到,比如N-gram模型的語言模型也往往需要平滑來進行處理,可以學習一些平滑處理歐。

1.3. 算法穩定性(Algorithmic Stability)與擾動(Perturbation)

在機器學習或統計學習模型中,我們常常需要考慮算法的穩定性,即算法對於數據擾動的魯棒性。“模型的泛化誤差由誤差(Bias)和方差(Variance)共同決定,而高方差是不穩定性的罪魁禍首”。

簡單的說就是,如果一個算法在輸入值發生微小變化時就產生了巨大的輸出變化,我們就可以說這個算法是不穩定的。此處的算法不僅僅是說機器學習算法,也代表“中間過程”所涉及的其他算法,給出幾個具體的例子:

  • 矩陣求逆(Inverting a Matrix)的過程就屬於不穩定的,我們常常會選擇避開矩陣求逆。有興趣的讀者可以進一步了解其原因。
  • 另一個有趣的例子是神經網絡中的批量學習(Batch Learning),即訓練神經網絡時不一個個例子的訓練而是批量的學習訓練數據。在選擇對應的批量尺寸(Batch Size)和相對應的學習速率(Learning Rate)時需要特別小心,錯誤的學習率和尺寸會導致不穩定的學習過程。當我們以小批量進行學習的時候,小樣本中的高方差(High Variance)導致我們學到的梯度(Gradient)很不精確,在這種情況下,應該使用小學習速率防止我們步子邁得太大!相反的,當我們的批量尺寸選的較大時,可以放心的使用較大的速率。
  • 決策樹(Decision Tree)的性質導致它也屬於一種不穩定的模型。訓練數據中的微小變化甚至可以改變決策樹的結構,以至於我們對於決策樹的可信度總是畫上一個問號。為了解決其不穩定的問題,研究人員發明了集成學習(Ensemble Learning),其中的Bagging就通過降低其方差的方法來增強其穩定性。

於是為了方便,我們歸納出一部分穩定模型。比較常見的模型有各種支持向量機(SVM)的衍生模型,這也是SVM在本世紀初大火的原因的之一)

2. 數據的穩定性(Data Stability)

嚴格意義上說,數據穩定性往往特指的是時間序列(Time Series)的穩定性。而筆者此處指的是廣義上的數據,不僅僅是時間序列。從根本上說,數據的穩定性主要取決於其Variance。

2.1. 獨立同分布(Independent Identically Distributed)與泛化能力(Generalization Ability)

一個機器學習模型的泛化能力指的是其在新樣本上的擬合能力。模型能夠獲得強泛化能力的數據保證就是其訓練數據是獨立同分布從母體分布上采樣而得。讓我們用一點點統計學的知識....

假設我們有一個母體(Population),它的分布是1到100的正整數:

\bm{D} = \{1,2,3,...,98,99,100\}

假設我們有3個從D中得到的采樣:

  • \bm{D}_1 = \{1,4,16,25,36,49,64,81\}
  • \bm{D}_2 = \{10,20,30,40,50,60,70,80,90\}
  • \bm{D}_3 = \{1,2,3,4,5,6,7,8,9\}

我們會發現第一個采樣好像都是平方數,第二個采樣都是十的倍數,而第三個采樣似乎都是小於10的連續整數。在這種采樣下,我們可以大膽的猜測學習模型無法通過學習這三個數據集而得到良好的泛化能力....因為它們並不是獨立同分布的采樣。

那么讀者會問了,那什么才算是獨立同分布的采樣,首先:

  1. 我們希望采樣的數據不是故意的挑選的,比如刻意挑出了一堆平方數
  2. 我們希望采樣的數據是從同一個分布里面挑的,而不是從幾個分布中各挑幾個...

因此如何保證我們的訓練數據足夠穩定呢?筆者有幾句看起來像廢話的建議:

  1. 訓練數據越多越好...這樣可以降低數據中的偶然性,降低Variance
  2. 確保訓練數據和母體數據及預測數據來自於一個分布。舉例,你不能用統計學家的平均智商來預測生物學家的平均智商,這不公平...至於對哪一方不公平,留給讀者思考。

因此數據的穩定性的基本前提就是獨立同分布,且數量越多越好。穩定的數據可以保證模型的經驗誤差(Empirical Risk)約等於其泛化誤差(Generalization Risk)。

2.2. 新常態: 類別不平衡

越來越多的機器學習問題都會遭遇不平衡的數據分布,此處的不平衡可以指很多事情,比如二分類問題中的正例和反例數量懸殊。但需要注意的是,如果母體的分布本身就是不平衡的,千萬不要通過采樣來使其分布平衡。這樣就違反了獨立同分布的采樣!

面對天生不平衡的數據,我們有很多做法可以進行處理,比較常見的再平衡做法包括:

  • 過采樣(Over-Sampling): 將數據量較少的的分類重復利用
  • 欠采樣(Down-Sampling):將數據量較多的分類選擇性丟棄一部分。

在類似的情況下,往往集成學習的表現非常好,這都需要歸功於集成學習可以有效的降低Variance。讀者必須注意,無論是過采樣還是欠采樣都會帶來問題,比如過采樣容易導致過擬合但欠采樣其實浪費了數據。

因此不平衡往往也帶來了穩定性問題,而究其根本還是因為過高的Variance。

3. 性能的穩定性 - “理論衛道士”

評估機器學習模型的穩定性(Stability)和評估機器學習的表現(Performance)有本質上的不同,不能簡單的通過評估准確率這種指標來說一個機器學習穩定與否。舉個最簡單的例子,假設一個模型一會兒表現特別好,一會兒比較特別差,我們敢用這個模型於實際生產中嗎?說白了,穩定性還是由於數據的方差Variance決定。

那么有小伙伴說了,我們或許可以用交叉驗證(cross-validation)來評估一個算法模型的穩定性。沒錯這是個正確的思路,但最大的問題,就是交叉驗證太慢了。不管是五折(5-fold)還是十折(10-fold)都需要較長的時間及重復運算。生命是寶貴的,1s都不能浪費!

因此我們一般通過計算學習理論(Computational Learning Theory)有時候也叫統計計算理論(Statistical Learning Theory)來對算法進行分析。介紹兩個框架供大家參考:

  • 概率近似正確框架(Probably Approximately Correct, PAC)。PAC框架主要回答了一個問題:一個學習算法是否可以在多項式函數的時間復雜度下從樣本 \bm{x} 中近似的學到一個概念,並保證誤差在一定的范圍之內。
  • 界限出錯框架(Mistake Bound Framework, MBF)。MBF從另一個角度回答了一個問題,即一個學習模型在學習到正確概念前在訓練過程中會失誤多少次?

有鑒於篇幅以及這個概念的深度和廣度,筆者會在以后的文章中以專題的形式展開。但計算學習理論為量化學習模型穩定性指出了一個方向,同時也緩和了統計學習對機器學習長久以來的偏見--機器學習缺乏理論基礎。

只打算進行實踐而不打算在機器學習領域進行研究的讀者,不必過分深究到底什么是PAC,因為其實用性是有限的,而且還會用到很多概率論的知識。

4. 小感悟

本文的目的不是列出所有的穩定性問題,也不是想讓大家杯弓蛇影,懷疑一切。筆者只是單純想借着這篇文章說明機器學習是一門交叉學科,它不僅需要你了解計算機上面的浮點精度防止溢出,還需要你了解統計中的數據采樣過程。

從這個角度出發,計算機科學出身的讀者要放寬自己的視野,還有很多其他領域與機器學習息息相關;而統計學或者數學出身的朋友也不要覺得計算機僅僅是運算工具,你們碰到的很多問題其實說白了是運算性問題。然而,在穩定之外,對於未知領域的探索,才是創新。因此放寬“穩定”的界限,不斷探尋真理的邊界。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM