Caffe源碼解析1:Blob


轉載請注明出處,樓燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/

首先看到的是Blob這個類,Blob是作為Caffe中數據流通的一個基本類,網絡各層之間的數據是通過Blob來傳遞的。這里整個代碼是非常規范的,基本上條件編譯,命名空間,模板類,各種不太經常看到的關鍵字如exlicit,inline等等。
首先提一下explicit關鍵字的作用是禁止單參數構造函數的隱式轉換,具體含義谷歌即可。還有inline的作用,iniline主要是將代碼進行復制,擴充,會使代碼總量上升,好處就是可以節省調用的開銷,能提高執行效率。

1主要變量

shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
shared_ptr<SyncedMemory> shape_data_;
vector<int> shape_;
int count_;
int capacity_;

BLob只是一個基本的數據結構,因此內部的變量相對較少,首先是data_指針,指針類型是shared_ptr,屬於boost庫的一個智能指針,這一部分主要用來申請內存存儲data,data主要是正向傳播的時候用的。同理,diff_主要用來存儲偏差,update data,shape_datashape_都是存儲Blob的形狀,一個是老版本一個是新版本。count表示Blob中的元素個數,也就是個數*通道數*高度*寬度,capacity表示當前的元素個數,因為Blob可能會reshape。

2主要函數

template <typename Dtype>
class Blob {
 public:
  Blob()
       : data_(), diff_(), count_(0), capacity_(0) {}

  /// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.
  explicit Blob(const int num, const int channels, const int height,
      const int width);
  explicit Blob(const vector<int>& shape);

  /// @brief Deprecated; use <code>Reshape(const vector<int>& shape)</code>.
  void Reshape(const int num, const int channels, const int height,
      const int width);
  

其中Blob作為一個最基礎的類,其中構造函數開辟一個內存空間來存儲數據,Reshape函數在Layer中的reshape或者forward操作中來adjust dimension。同時在改變Blob大小時,內存將會被重新分配如果內存大小不夠了,並且額外的內存將不會被釋放。對input的blob進行reshape,如果立馬調用Net::Backward是會出錯的,因為reshape之后,要么Net::forward或者Net::Reshape就會被調用來將新的input shape 傳播到高層

Blob類里面有重載很多個count()函數,主要還是為了統計Blob的容量(volume),或者是某一片(slice),從某個axis到具體某個axis的shape乘積。

inline int count(int start_axis, int end_axis)

並且Blob的Index是可以從負坐標開始讀的,這一點跟Python好像

inline int CanonicalAxisIndex(int axis_index) 

對於Blob中的4個基本變量num,channel,height,width可以直接通過shape(0),shape(1),shape(2),shape(3)來訪問。

計算offset

inline int offset(const int n, const int c = 0, const int h = 0, const int w = 0)
inline int offset(const vector<int>& indices)

offset計算的方式也支持兩種方式,一種直接指定n,c,h,w或者放到一個vector中進行計算,偏差是根據對應的n,c,h,w,返回的offset是((n * channels() + c) * height() + h) * width() + w

其實里面稍加留意可以看到有很多的

CHECK_GE
CHECK_LE
CHECK_EQ
....

等等看意思就知道了,肯定是在做比較Geater or Eqal這樣的意思。這其實是GLOG,谷歌的一個日志庫,Caffe里面用用了大量這樣的宏,看起來也比較直觀

void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,bool reshape = false);

從一個blob中copy數據 ,通過開關控制是否copy_diff,如果是False則copy data。reshape控制是否需要reshape。好我們接着往下看

inline Dtype data_at(const int n, const int c, const int h, const int w)
inline Dtype diff_at(const int n, const int c, const int h, const int w)
inline Dtype data_at(const vector<int>& index)
inline Dtype diff_at(const vector<int>& index)
inline const shared_ptr<SyncedMemory>& data()
inline const shared_ptr<SyncedMemory>& diff()

這一部分函數主要通過給定的位置訪問數據,根據位置計算與數據起始的偏差offset,在通過cpu_data*指針獲得地址。下面幾個函數都是獲得

const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
const int* gpu_shape() const;
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
Dtype* mutable_cpu_data();
Dtype* mutable_gpu_data();
Dtype* mutable_cpu_diff();
Dtype* mutable_gpu_diff();

可以看到這里有data和diff兩類數據,而這個diff就是我們所熟知的偏差,前者主要存儲前向傳遞的數據,而后者存儲的是反向傳播中的梯度

void Update();

看到update里面面調用了

caffe_axpy<float>(const int N, const float alpha, const float* X,float* Y)
{ cblas_saxpy(N, alpha, X, 1, Y, 1); }

這個函數在caffe的util下面的match-functions.cpp里面,主要是負責了線性代數庫的調用,實現的功能是

\[Y=alpha * X +beta*Y \]

也就是blob里面的data部分減去diff部分

void FromProto(const BlobProto& proto, bool reshape = true);
void ToProto(BlobProto* proto, bool write_diff = false) const;

這兩個函數主要是將數據序列化,存儲到BlobProto,這里說到Proto是谷歌的一個數據序列化的存儲格式,可以實現語言、平台無關、可擴展的序列化結構數據格式。Caffe里面數據的存儲都采用這一結構,這里就不深入展開,具體可以參照這篇文章,對於proto的序列化和反序列都講解的非常詳細http://www.w2bc.com/Article/34963

Dtype asum_data() const;//計算data的L1范數
Dtype asum_diff() const;//計算diff的L1范數
Dtype sumsq_data() const;//計算data的L2范數
Dtype sumsq_diff() const;//計算diff的L2范數
void scale_data(Dtype scale_factor);//將data部分乘以一個因子
void scale_diff(Dtype scale_factor);//將diff部分乘一個因子

這幾個函數是一些零散的功能,一看就懂。

void ShareData(const Blob& other);
void ShareData(const Blob& other);

這兩個函數看名字就知道了一個是共享data,一個是共享diff,具體就是將別的blob的data和響應的diff指針給這個Blob,實現數據的共享。同時需要注意的是這個操作會引起這個Blob里面的SyncedMemory被釋放,因為shared_ptr指針被用=重置的時候回調用響應的析構器。

bool ShapeEquals(const BlobProto& other);

這函數就不用說了,比較兩個Blob形狀是否相同
好了,基本上Blob的主要參數功能基本就涵蓋在里面了,以上只是我的拙見,如有紕漏,還望指出,萬分感謝。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2024 CODEPRJ.COM