Caffe學習系列(15):添加新層


如何在Caffe中增加一層新的Layer呢?主要分為四步:

(1)在./src/caffe/proto/caffe.proto 中增加對應layer的parameter message;

(2)在./include/caffe/***layers.hpp中增加該layer的類的聲明,***表示有common_layers.hpp,

data_layers.hpp, neuron_layers.hpp, vision_layers.hpp 和loss_layers.hpp等;

(3)在./src/caffe/layers/目錄下新建.cpp和.cu(GPU)文件,進行類實現。

(4)在./src/caffe/gtest/中增加layer的測試代碼,對所寫的layer前傳和反傳進行測試,測試還包括速度。(可省略,但建議加上)

  

  這位博主添加了一個計算梯度的網絡層,簡潔明瞭:

  http://blog.csdn.net/shuzfan/article/details/51322976

 

  這幾位博主增加了自定義的loss層,可供參考:

 

  http://blog.csdn.net/langb2014/article/details/50489305

 

  http://blog.csdn.net/tangwei2014/article/details/46815231

 我以添加precision_recall_loss層來學習代碼,主要是precision_recall_loss_layer.cpp的實現

#include <algorithm>  
#include <cfloat>  
#include <cmath>  
#include <vector>  
#include <opencv2/opencv.hpp>  
  
#include "caffe/layer.hpp"  
#include "caffe/util/io.hpp"  
#include "caffe/util/math_functions.hpp"  
#include "caffe/vision_layers.hpp"  
  
namespace caffe {  
  
// Layer setup: nothing layer-specific to initialize here; delegate
// entirely to the base LossLayer's LayerSetUp.
template <typename Dtype>  
void PrecisionRecallLossLayer<Dtype>::LayerSetUp(  
  const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {  
  LossLayer<Dtype>::LayerSetUp(bottom, top);  
}  
// Shape the per-element loss buffer and verify that the prediction and
// label blobs have identical dimensions.
template <typename Dtype>
void PrecisionRecallLossLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*> &bottom,
  const vector<Blob<Dtype>*> &top) {
  // The base LossLayer shapes the scalar loss output first.
  LossLayer<Dtype>::Reshape(bottom, top);
  Blob<Dtype> *const pred = bottom[0];  // predicted scores
  Blob<Dtype> *const gt = bottom[1];    // ground-truth labels
  // Internal buffer mirrors the prediction blob, one entry per element.
  loss_.Reshape(pred->num(), pred->channels(), pred->height(), pred->width());

  // Predictions and labels must agree along every axis.
  CHECK_EQ(pred->num(), gt->num())
      << "The number of num of data and label should be same.";
  CHECK_EQ(pred->channels(), gt->channels())
      << "The number of channels of data and label should be same.";
  CHECK_EQ(pred->height(), gt->height())
      << "The heights of data and label should be same.";
  CHECK_EQ(pred->width(), gt->width())
      << "The width of data and label should be same.";
}
//前向傳導 template
<typename Dtype> void PrecisionRecallLossLayer<Dtype>::Forward_cpu( const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) { const Dtype *data = bottom[0]->cpu_data(); const Dtype *label = bottom[1]->cpu_data();
const int num = bottom[0]->num(); //num和count什么區別 const int dim = bottom[0]->count() / num; const int channels = bottom[0]->channels(); const int spatial_dim = bottom[0]->height() * bottom[0]->width();
//存疑?
const int pnum = this->layer_param_.precision_recall_loss_param().point_num(); top[0]->mutable_cpu_data()[0] = 0;
//對於每個通道
for (int c = 0; c < channels; ++c) { Dtype breakeven = 0.0; Dtype prec_diff = 1.0; for (int p = 0; p <= pnum; ++p) { int true_positive = 0; //統計每類的個數 int false_positive = 0; int false_negative = 0; int true_negative = 0;
for (int i = 0; i < num; ++i) { const Dtype thresh = 1.0 / pnum * p; //計算閾值? for (int j = 0; j < spatial_dim; ++j) {
//取得相應的值和標簽
const Dtype data_value = data[i * dim + c * spatial_dim + j]; const int label_value = (int)label[i * dim + c * spatial_dim + j];
//統計
if (label_value == 1 && data_value >= thresh) { ++true_positive; } if (label_value == 0 && data_value >= thresh) { ++false_positive; } if (label_value == 1 && data_value < thresh) { ++false_negative; } if (label_value == 0 && data_value < thresh) { ++true_negative; } } }
//計算precision和recall Dtype precision
= 0.0; Dtype recall = 0.0; if (true_positive + false_positive > 0) { precision = (Dtype)true_positive / (Dtype)(true_positive + false_positive); } else if (true_positive == 0) { //都是負類? precision = 1.0; } if (true_positive + false_negative > 0) { recall = (Dtype)true_positive / (Dtype)(true_positive + false_negative); } else if (true_positive == 0) { recall = 1.0; } if (prec_diff > fabs(precision - recall) //如果二c者相差小 && precision > 0 && precision < 1 && recall > 0 && recall < 1) { breakeven = precision; //保留 prec_diff = fabs(precision - recall); } } top[0]->mutable_cpu_data()[0] += 1.0 - breakeven; //計算誤差 } top[0]->mutable_cpu_data()[0] /= channels; //??? } //反向 template <typename Dtype> void PrecisionRecallLossLayer<Dtype>::Backward_cpu( const vector<Blob<Dtype>*> &top, const vector<bool> &propagate_down, const vector<Blob<Dtype>*> &bottom) { for (int i = 0; i < propagate_down.size(); ++i) { if (propagate_down[i]) { NOT_IMPLEMENTED; } } } #ifdef CPU_ONLY STUB_GPU(PrecisionRecallLossLayer); #endif //注冊該層 INSTANTIATE_CLASS(PrecisionRecallLossLayer); REGISTER_LAYER_CLASS(PrecisionRecallLoss); } // namespace caffe

 

  1. template <typename Dtype>  
  2. void PrecisionRecallLossLayer<Dtype>::Forward_cpu(  
  3.   const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {  
  4.   const Dtype *data = bottom[0]->cpu_data();  
  5.   const Dtype *label = bottom[1]->cpu_data();  
  6.   const int num = bottom[0]->num();  
  7.   const int dim = bottom[0]->count() / num;  
  8.   const int channels = bottom[0]->channels();  
  9.   const int spatial_dim = bottom[0]->height() * bottom[0]->width();  
  10.   const int pnum =  
  11.     this->layer_param_.precision_recall_loss_param().point_num();  
  12.   top[0]->mutable_cpu_data()[0] = 0;  
  13.   for (int c = 0; c < channels; ++c) {  
  14.     Dtype breakeven = 0.0;  
  15.     Dtype prec_diff = 1.0;  
  16.     for (int p = 0; p <= pnum; ++p) {  
  17.       int true_positive = 0;  
  18.       int false_positive = 0;  
  19.       int false_negative = 0;  
  20.       int true_negative = 0;  
  21.       for (int i = 0; i < num; ++i) {  
  22.         const Dtype thresh = 1.0 / pnum * p;  
  23.         for (int j = 0; j < spatial_dim; ++j) {  
  24.           const Dtype data_value = data[i * dim + c * spatial_dim + j];  
  25.           const int label_value = (int)label[i * dim + c * spatial_dim + j];  
  26.           if (label_value == 1 && data_value >= thresh) {  
  27.             ++true_positive;  
  28.           }  
  29.           if (label_value == 0 && data_value >= thresh) {  
  30.             ++false_positive;  
  31.           }  
  32.           if (label_value == 1 && data_value < thresh) {  
  33.             ++false_negative;  
  34.           }  
  35.           if (label_value == 0 && data_value < thresh) {  
  36.             ++true_negative;  
  37.           }  
  38.         }  
  39.       }  
  40.       Dtype precision = 0.0; 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM