論文: Accurate, Large MiniBatch SGD:Training ImageNet in 1 Hour
因為目前的 network 和 dataset 越來越大,隨之而來的是training times的不斷攀升。為了加快網絡的訓練,采用 distributed synchronous SGD , 將 SGD minibatch 划分到一個同步工作池內進行訓練。
因為 distributed 的原因,minibatch size的增大,本paper采用的linear scaling rule對learning rate同時的放大。
Linear Scaling Rule:When the minibatch size is multiplied by k, multiply the learning rate by k.
在其他 hyper parameter不變的情況下,N' = k * N, 可以達到多個minibatch size在一次實現對網絡的update。
Warmup method:
Constant warmup: 在train前幾個epoch(一般前5 epochs)時采用較小的constant learning rate,但是對於大的learning rate,constant warmup不能很好的初始化網絡。
gradual warmup: 在 training的前幾個 epoch,逐漸將learning rate由小到大的提高,讓training在開始的時候健康的收斂。
Batch Normalization:
在 distributed training情況下,每一個per-work的sample可以看成是每一個minibatch,相互之間是獨立的。所以 underlying loss function可以不變。
BN statistics不應該在all workers之間交叉計算,不僅為了減少交流,也為保持為了優化的同樣的underlying loss function。
distributed SGD 對其他hyper parameter的影響:
1, weight decay:scaling the cross-entropy loss is not equivalent to scaling the learning rate
cross-entropy loss是 sample-dependent term,而regularization是weight-dependent term。
2, Momentum correction: Apply momentum correction after changing learning rate if using
3, Gradient aggregation: Normalize the per-worker loss by total minibatch size kn, not per-worker size n.
4, Data shuffling: Use a single random shuffling of the training data (per epoch) that is divided amongst all k workers.
在同一個worker內部,多個GPU的,則使用NCCL,進行多個GPU內部的buffer統一計算。
worker之間的通信交流:
對於該distributed SGD,只有Gradient Aggregation才需要all-worker之間的通信交流,在這里使用了兩種算法 the recursive halving and doubling algorithm 和 bucket algorithm (ring algorithm)。
其中 halving/doubling algorithm有兩步驟:reduce-scatter 和 all-gather,在 reduce-scatter階段,使用兩兩servers組成pair,進行buffer的交換,如0和1,2和3, server0 發送第二半的buffer給server1,同時接受來自server1的第一半的buffer。
all-gather則使用類似樹形結構對所有的 server 進行Gradient的gather。
實驗部分:
warmup: Large minibatch sizes are challenged by optimization difficulties in early training.
if the optimization issues are addressed, there is no apparent generalization degradation observed using large minibatch training.
distributed SGD的時間消耗:
本人觀點: 給予很多 deep network一個分布式計算的理論和可行性驗證。