本博客是針對Andrew Ng在Coursera上的machine learning課程的學習筆記。
在大數據集上進行學習(Learning with Large Data Sets)
由於機器學習系統的性能表現往往要求其算法是low biased(在訓練集上的訓練誤差小),並且在盡可能大的數據集上做訓練。
It's not who has the best algorithm that wins. It's who has the most data.
比如對於一般線性回歸或者logistic回歸,有如圖的參數迭代更新的公式,假設我們如果有一百萬個訓練樣本(m=100,000,000),則每一次更新參數都需要做一百萬次的加和,那么計算量就很大。另外,對於high bias的算法學習曲線(如下圖的靠右圖),此時增加樣本數量也不太可能減少誤差,可能合適地增加輸入特征才是更有效的。而如左邊的情況high variance,用更大的數據集是更有效的。
隨機梯度下降(Stochastic Gradient Descent)
由於之前講的梯度下降(又稱Batch Gradient Descent)在樣本量大時計算量非常大,更新速度會很慢,為了避免這個問題,我們可以采用隨機梯度下降。
具體做法是:我們不再在對參數進行更新時遍歷加總所有樣本的誤差值,而是每次迭代更新從中只隨機選取一個樣本進行誤差計算,(本質是先將數據集打亂,再逐一進行參數學習,但還是要對所有的樣本進行一次遍歷),因此保證了參數是在快速地向全局最優解靠近的(但可能由於樣本選取的隨機性,導致梯度下降的方向並不是那么穩定地向全函數最小值處行進,相當於是犧牲穩度換取速度)。
因此相對於batch GD,stochastic GD可以在很多時候實現更高效地學習。因為stochastic GD每次迭代更新參數,都只需要用到一個樣本,而遍歷整個樣本的輪數可能進行1-10次后就能獲得非常好的假設了,而batch GD是每一次更新一個參數都需要遍歷整個數據集,那么每次迭代都要用到m個樣本。
小堆梯度下降(Mini-Batch Gradient Descent)
小堆梯度下降有時候可以比隨機梯度下降速度更快!
小堆梯度下降其實就是將隨機梯度下降中只用一個樣本進行每輪迭代的做法,變成了用b個樣本進行每輪迭代(此之謂mini-batch,因為每輪迭代用到的樣本數量在1-m之間)。
而什么時候小堆GD效率會比隨機GD高呢?那就是我們有比較好的vectorization的計算實現時。
保證隨機GD的收斂與學習速率的選擇
如何知道我們的GD算法在像收斂的方向運行?
- 在Batch GD中,我們是畫損失函數和GD迭代次數的圖像。
- 在Stochastic GD中,我們可以在固定的迭代次數后,畫成本函數的在這些迭代上的平均值
下圖中藍色線代表較大的學習速率$\alpha$,紅色代表較小的學習速率(每次迭代參數更新的幅度更小)。四幅圖代表了stochastic GD的圖像可能出現的情況。
上面兩幅圖的區別是畫圖的采樣點的頻率不同,左上方的圖是畫每迭代1000次的成本函數,右上方的圖每迭代5000次的平均成本函數圖像。由於分母變大,分子相對變化較小,因此迭代次數高的函數圖像更平穩。
下面兩幅圖也是可能出現的情況。左下方的圖藍色的線是由於學習速率過大,不穩定,應該調小學習速率。而右下方的圖則是曲線上揚,說明學習速率太大,算法發散了,也需要調小學習速率。
可以讓學習速率隨着迭代次數增加而減少,以保證逐漸穩定地收斂到最優解。
在線學習(Online Learning)
在很多機器學習運行的系統中,很多學習數據是不斷在產生的,如何可以讓系統隨着數據的涌入保持學習,而不是只能獲得一次性用固定數據訓練的模型,以獲得更優的性能,就是在線學習所要解決的問題。
給定以下一個快遞服務網站的情形:
如圖所示,在線學習系統在每次有新的數據產生(一個用戶訪問了網站,並給出是否使用快遞服務的決策)時,系統都會對參數進行一次更新。
另一個例子:電商網站的產品搜索。
Map Reduce 和 數據並行化
有時候,數據集太大,如果只在一台主機上運行,可能很久才能獲得結果,於是我們希望可能用盡量少的時間獲得好的結果。
首先來看map-reduce的做法:將數據集按照我們手頭擁有的計算資源的數量,即我們可以用來運行機器學習算法的主機的數量平均分配。然后,每個主機利用分配給它那部分的樣本運行算法,去計算損失函數的偏導數值,並獲得一個屬於該主機的結果$temp^{(i)}_j$。最后,將所有主機的結果結合在一起,用來對參數做更新。
而很多機器學習算法本質就是對訓練集上進行和加總的一些計算。
而其實不是只有多台主機才能實現並行化,在一台多核的主機上也可以進行,做法和多主機的情況是類似的。而且在這種情況下,有一個好處是,不需要擔心因為網絡延遲導致的數據傳輸緩慢等問題,而這是多主機並行化可能會面對的問題。