Caffe BatchNormalization 推導
總所周知,BatchNormalization通過對數據分布進行歸一化處理,從而使得網絡的訓練能夠快速並簡單,在一定程度上還能防止網絡的過擬合,通過仔細看過Caffe的源碼實現后發現,Caffe是通過BN層和Scale層來完整的實現整個過程的。
談談理論與公式推導
那么再開始前,先進行必要的公式說明:定義\(L\)為網絡的損失函數,BN層的輸出為\(y\),根據反向傳播目前已知 \(\frac{\partial L}{\partial y_i}\),其中:
推導的過程中應用了鏈式法則:
則只需要着重討論公式 \(\frac{\partial y_j}{\partial x_i}\)
分布探討:
(1) \(\overline x\)對\(x_i\)的導函數
(2) \(\delta^2\)對\(x_i\)的導函數
由於 \(\sum_{j=1}^{m}2*(x_j-\overline x) = 2* \sum_{i=1}^{m}x_i - n*\overline x = 0\)
所以: \(\frac{\partial \delta^2}{\partial x_i} = \frac{2}{m}*(x_i-\overline x)\)
具體推導:
此處當\(j\)等於\(i\)成立時時,分子求導多一個 \(x_i\)的導數
根據上式子,我們代入鏈式法則的式子
我們提出 \((\delta^2+\epsilon)^{-1/2}:\)
至此,我們可以對應到caffe的具體實現部分
// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
//
// dE(Y)/dX =
// (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
// ./ sqrt(var(X) + eps)
//
// where \cdot and ./ are hadamard product and elementwise division,
談談具體的源碼實現
知道了BN層的公式與原理,接下來就是具體的源碼解析,由於考慮到的情況比較多,所以\(Caffe\)中的BN的代碼實際上不是那么的好理解,需要理解,BN的歸一化是如何歸一化的:
HW的歸一化,求出NC個均值與方差,然后N個均值與方差求出一個均值與方差的Vector,size為C,即相同通道的一個mini_batch的樣本求出一個mean和variance
成員變量
BN層的成員變量比較多,由於在bn的實現中,需要記錄mean_,variance_,歸一化的值,同時根據訓練和測試實現也有所差異。
Blob<Dtype> mean_,variance_,temp_,x_norm; //temp_保存(x-mean_x)^2
bool use_global_stats_;//標注訓練與測試階段
Dtype moving_average_fraction_;
int channels_;
Dtype eps_; // 防止分母為0
// 中間變量,理解了BN的具體過程即可明了為什么需要這些
Blob<Dtype> batch_sum_multiplier_; // 長度為N*1,全為1,用以求和
Blob<Dtype> num_by_chans_; // 臨時保存H*W的結果,length為N*C
Blob<Dtype> spatial_sum_multiplier_; // 統計HW的均值方差使用
成員函數
成員函數主要也是LayerSetUp,Reshape,Forward和Backward,下面是具體的實現:
LayerSetUp,層次的建立,相應數據的讀取
//LayerSetUp函數的具體實現
template <typename Dtype>
void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top){
// 參見proto中添加的BatchNormLayer
BathcNormParameter param = this->layer_param_.batch_norm_param();
moving_average_fraction_ = param.moving_average_fraction();//默認0.99
//這里有點多余,好處是防止在測試的時候忘寫了use_global_stats時默認true
use_global_stats_ = this->phase_ == TEST;
if (param.has_use_global_stat()) {
use_global_stats_ = param.use_global_stats();
}
if (bottom[0]->num_axes() == 1) { //這里基本看不到為什么.....???
channels_ = 1;
}
else{ // 基本走下面的通道,因為輸入是NCHW
channels_ = bottom[0]->shape(1);
}
eps_ = param.eps(); // 默認1e-5
if (this->blobs_.size() > 0) { // 測試的時候有值了,保存了均值方差和系數
//保存mean,variance,
}
else{
// BN層的內部參數的初始化
this->blobs_.resize(3); // 均值滑動,方差滑動,滑動系數
vector<int>sz;
sz.push_back(channels_);
this->blobs_[0].reset(new Blob<Dtype>(sz)); // C
this->blobs_[1].reset(new Blob<Dtype>(sz)); // C
sz[0] = 1;
this->blobs_[2].reset(new Blob<Dtype>(sz)); // 1
for (size_t i = 0; i < 3; i++) {
caffe_set(this->blobs_[i]->count(),Dtype(0),
this->blobs_[i]->mutable_cpu_data());
}
}
}
Reshape,根據BN層在網絡的位置,調整bottom和top的shape
Reshape層主要是完成中間變量的值,由於是按照通道求取均值和方差,而CaffeBlob是NCHW,因此先求取了HW,后根據BatchN求最后的輸出C,因此有了中間的batch_sum_multiplier_和spatial_sum_multiplier_以及num_by_chans_其中num_by_chans_與前兩者不想同,前兩者為方便計算,初始為1,而num_by_chans_為中間過渡
template <typename Dtype>
void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
if (bottom[0]->num_axes() >= 1) {
CHECK_EQ(bottom[0]->shape(1),channels_);
}
top[0]->ReshapeLike(*bottom[0]); // Reshape(bottom[0]->shape());
vector<int>sz;
sz.push_back(channels_);
mean_.Reshape(sz);
variance_.Reshape(sz);
temp_.ReshapeLike(*bottom[0]);
x_norm_.ReshapeLike(*bottom[0]);
sz[0] = bottom[0]->shape(0); //N
// 后續會初始化為1,為求Nbatch的均值和方差
batch_sum_multiplier_.Reshape(sz);
caffe_set(batch_sum_multiplier_.count(),Dtype(1),
batch_sum_multiplier_.mutable_cpu_data());
int spatial_dim = bottom[0]->count(2);//H*W
if (spatial_sum_multiplier_.num_axes() == 0 ||
spatial_sum_multiplier_.shape(0) != spatial_dim) {
sz[0] = spatial_dim;
spatial_sum_multiplier_.Reshape(sz); //初始化1,方便求和
caffe_set(spatial_sum_multiplier_.count(),Dtype(1)
spatial_sum_multiplier_.mutable_cpu_data());
}
// N*C,保存H*W后的結果,會在計算中結合data與spatial_dim求出
int numbychans = channels_*bottom[0]->shape(0);
if (num_by_chans_.num_axes() == 0 ||
num_by_chans_.shape(0) != numbychans) {
sz[0] = numbychans;
num_by_chans_.Reshape(sz);
}
}
Forward 前向計算
前向計算,根據公式完成前計算,x_norm與top相同,均為歸一化的值
template <typename Dtype>
void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// 想要完成前向計算,必須計算相應的均值與方差,此處的均值與方差均為向量的形式c
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
int num = bottom[0]->shape(0);// N
int spatial_dim = bottom[0]->count(2); //H*W
if (bottom[0] != top[0]) {
caffe_copy(top[0]->count(),bottom_data,top_data);//先復制一下
}
if (use_global_stats_) { // 測試階段,使用全局的均值
const Dtype scale_factory = this_->blobs_[2]->cpu_data()[0] == 0?
0:1/this->blobs_[2]->cpu_data()[0];
// 直接載入訓練的數據 alpha*x = y
caffe_cpu_scale(mean_.count(),scale_factory,
this_blobs_[0]->cpu_data(),mean_.mutable_cpu_data());
caffe_cpu_scale(variance_.count(),scale_factory,
this_blobs_[1]->cpu_data(),variance_.mutable_cpu_data());
}
else{ //訓練階段 compute mean
//1.計算均值,先計算HW的,在包含N
// caffe_cpu_gemv 實現 y = alpha*A*x+beta*y;
// 輸出的是channels_*num,
//每次處理的列是spatial_dim,由於spatial_sum_multiplier_初始為1,即NCHW中的
// H*W各自相加,得到N*C*average,此處多除以了num,下一步可以不除以
caffe_cpu_gemv<Dtype>(CBlasNoTrans,channels_*num,spatial_dim,
1./(spatial_dim*num),bottom_data,spatial_sum_multiplier_.cpu_data()
,0.,num_by_chans_.mutable_cpu_data());
//2.計算均值,計算N各的平均值.
// 由於輸出的是channels個均值,因此需要轉置
// 上一步得到的N*C的均值,再按照num求均值,因為batch_sum全部為1,
caffe_cpu_gemv<Dtype>(CBlasTrans,num,channels_,1,
num_by_chans_.cpu_data(),batch_sum_multiplier_.cpu_data(),
0,mean_.mutable_cpu_data());
}
// 此處的均值已經保存在mean_中了
// 進行 x - mean_x 操作,需要注意按照通道,即先確定x屬於哪個通道.
// 因此也是進行兩種,先進行H*W的減少均值
// caffe_cpu_gemm 實現alpha * A*B + beta* C
// 輸入是num*1 * 1* channels_,輸出是num*channels_
caffe_cpu_gemm<Dtype>(CBlasNoTrans,CBlasNoTrans,num,channels_,1,1,
batch_sum_multiplier_.cpu_data(),mean_.cpu_data(),0,
num_by_chans_.mutable_cpu_data());
//同上,輸入是num*channels_*1 * 1* spatial = NCHW
// top_data = top_data - mean;
caffe_cpu_gemm<Dtype>(CBlasNoTrans,CBlasNoTrans,num*channels_,
spatial_dim,1,-1,num_by_chans_.cpu_data(),
spatial_sum_multiplier_.cpu_data(),1, top_data());
// 解決完均值問題,接下來就是解決方差問題
if (use_global_stats_) { // 測試的方差上述已經讀取了
// compute variance using var(X) = E((X-EX)^2)
// 此處的top已經為x-mean_x了
caffe_powx(top[0]->count(),top_data,Dtype(2),
temp_.mutable_cpu_data());//temp_保存(x-mean_x)^2
// 同均值一樣,此處先計算spatial_dim的值
caffe_cpu_gemv<Dtype>(CblasNoTrans,num*channels_,spatial_dim,
1./(num*spatial_dim),temp_.cpu_data(),
spatial_sum_multiplier_.cpu_data(),0,
num_by_chans_.mutable_cpu_data();
)
caffe_cpu_gemv<Dtype>(CBlasTrans,num,channels_,1.,
num_by_chans_.cpu_data(),batch_sum_multiplier_.cpu_data(),
0,variance_.mutable_cpu_data());// E((X_EX)^2)
//均值和方差計算完成后,需要更新batch的滑動系數
this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
this->blobs_[2]->mutable_cpu_data()[0] += 1;
caffe_cpu_axpby(mean_.count(),Dtype(1),mean_.cpu_data(),
moving_average_fraction_,this->blobs_[0]->mutable_cpu_data());
int m = bottom[0]->count()/channels_;
Dtype bias_correction_factor = m > 1? Dtype(m)/(m-1):1;
caffe_cpu_axpby(variance_.count(),bias_correction_factor,
variance_.cpu_data(),moving_average_fraction_,
this->blobs_[1]->mutable_cpu_data());
}
// 方差求個根號,加上eps為防止分母為0
caffe_add_scalar(variance_.count(),eps_,variance_.mutable_cpu_data());
caffe_powx(variance_.count(),variance_.cpu_data(),Dtype(0.5),
variance_.mutable_cpu_data());
// top_data = x-mean_x/sqrt(variance_),此處的top_data已經轉化為x-mean_x了
// 同減均值,也要分C--N*C和 N*C --- N*C*H*W
// N*1 * 1*C == N*C
caffe_cpu_gemm<Dtype>(CBlasNoTrans,CBlasNoTrans,num,channels_,1,1,
batch_sum_multiplier_.cpu_data(),variance_.cpu_data(),0,
num_by_chans_.mutable_cpu_data());
// NC*1 * 1* spatial_dim = NCHW
caffe_cpu_gemm<Dtype>(CBlasNoTrans,CBlasNoTrans,num*channels_,spatial_dim,
1, 1.,num_by_chans_.cpu_data(),spatial_sum_multiplier_.cpu_data(), 0,
temp_.mutable_cpu_data());
// temp最終保存的是sqrt(方差+eps)
caffe_cpu_div(top[0].count(),top_data,temp_.cpu_data(),top_data);
}
整個forward過程按照x-mean/variance的過程進行,包含了求mean和variance,他們都是C*1的向量,然后輸入的是NCHW,因此通過了gemm操作做廣播填充到整個featuremap然后完成減mean和除以方差的操作。同時需要注意caffe的inplace操作,所以用x_norm保存原始的top值,后續修改也不會影響它。
Backward過程,根據梯度,反向計算
Backward過程會根據前面所推導的公式進行計算,具體的實現如下面所示.
template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff;
if (bottom[0] != top[0]) { // 判斷是否同名
top_diff = top[0]->cpu_diff();
}
else{
caffe_copy(x_norm_.count(),top[0]->cpu_diff(),x_norm_.mutable_cpu_diff());
top_diff = x_norm_.cpu_diff();
}
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
if (use_global_stats_) { // 測試階段
caffe_div(temp_.count(),top_diff,temp_.cpu_data(),bottom_diff);
return ; // 測試階段不需要計算梯度。
}
const Dtype* top_data = x_norm_.cpu_data();
int num = bottom[0]->shape(0); //n
int spatial_dim = bottom[0]->count(2); // H*W
// 根據推導的公式開始具體計算。
// dE(Y)/dX =
// (top_diff- mean(top_diff) - mean(top_diff \cdot Y) \cdot Y)
// ./ sqrt(var(X) + eps)
// sum(top_diff \cdot Y) ,y為x_norm_ NCHW,求取的均先求C通道的均值
caffe_mul(temp_.count(),top_data,top_diff,bottom_diff);
//NC*HW* HW*1 = NC*1
caffe_cpu_gemv<Dtype>(CblasNoTrans,channels_*num,spatial_dim,1.,
bottom_diff,spatial_sum_multiplier_.cpu_data(),0,
num_by_chans_.mutable_cpu_data());
// (NC)^T*1 * N*1 = C*1
caffe_cpu_gemv<Dtype>(CBlasTrans,num,channels_,1.,
num_by_chans_.cpu_data(),batch_sum_multiplier_.cpu_data(),
0,mean_.mutable_cpu_data());
//reshape broadcast
// N*1 * 1* C = N* C
caffe_cpu_gemm<Dtype>(CblasNoTrans,CblasNoTrans,num,channels_,1,1,
batch_sum_multiplier_.cpu_data(),mean_.cpu_data(),0,
num_by_chans_.mutable_cpu_data());
// N*C *1 * 1* HW = NC* HW
caffe_cpu_gemm<Dtype>(CblasNoTrans,CblasNoTrans,num*channels_,spatial_dim,
1,1.,num_by_chans_.cpu_data(),spatial_sum_multiplier_.cpu_data(),0,
bottom_diff);
//相當與 sum (DE/DY .\cdot Y)
// sum(dE/dY \cdot Y) \cdot Y
caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff);
// 完成了右邊一個部分,還有前面的 sum(DE/DY)和DE/DY
// 再完成sum(DE/DY)
caffe_cpu_gemv<Dtype>(CblasNoTrans,channels_*num,spatial_dim,1,
top_diff,spatial_sum_multiplier_.cpu_data(),0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CBlasTrans,num,channels_,1.,
num_by_chans_.cpu_data(),batch_sum_multiplier_.cpu_data(),0,
mean_.mutable_cpu_data());
//reshape broadcast
caffe_cpu_gemm<Dtype>(CblasNoTrans,CblasNoTrans,num,channels_,1,
1,batch_sum_multiplier_.cpu_data(),mean_.cpu_data(),0,
num_by_chans_.mutable_cpu_data());
// 現在完成了sum(DE/DY)+y*sum(DE/DY.\cdot y)
caffe_cpu_gemm<Dtype>(CblasNoTrans,CblasNoTrans,num*channels_,spatial_dim,
1,1.,num_by_chans_.cpu_data(),spatial_sum_multiplier_.cpu_data(),1,
bottom_diff);
//top_diff - 1/m * (sum(DE/DY)+y*sum(DE/DY.\cdot y))
caffe_cpu_axpby(bottom[0]->count(),Dtype(1),top_diff,
Dtype(-1/(num*spatial_dim)),bottom_diff);
// 前面還有常數項 variance_+eps
caffe_div(temp_.count(),bottom_diff,temp_.cpu_data(),bottom_diff);
}
backward的過程也是先求出通道的值,然后廣播到整個feature_map,來回兩次,然后調用axpby完成 top_diff - 1/m* (sum(top_diff)+ysum(top_diffy)))這里的y針對通道進行。
本文作者: 張峰
本文鏈接:http://www.enjoyai.site/2017/11/06/
版權聲明:本博客所有文章,均采用CC BY-NC-SA 3.0 許可協議。轉載請注明出處!