Caffe4——計算圖像均值


Caffe4——計算圖像均值

均值削減是數據預處理中常見的處理方式,按照之前在學習ufldl教程PCA的一章時,對於圖像介紹了兩種:第一種常用的方式叫做dimension_mean(個人命名),是依據輸入數據的維度,每個維度內進行削減,這個也是常見的做法;第二種叫做per_image_mean,ufldl教程上說,在natural images上訓練網絡時;給每個像素(這里只每個dimension)計算一個獨立的均值和方差是make little sense的;這是因為圖像本身具有統計不變性,即在圖像的一部分的統計特性和另一部分相同。作者最后建議,如果你訓練你的算法在非natural images(如mnist,或者在白背景存在單個獨立的物體),其他類型的規則化是值得考慮的。但是當在natural images上訓練時,per_image_mean是一個合理的默認選擇。

本文中在imagenet數據集上采用的是dimension_mean的方法。

 

一:程序開始

make_image_mean.sh文件調用代碼:

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. EXAMPLE=examples/imagenet  
  2. DATA=data/ilsvrc12  
  3. TOOLS=build/tools  
  4. $TOOLS/compute_image_mean $EXAMPLE/ilsvrc12_train_lmdb \  
  5. $DATA/imagenet_mean.binaryproto<strong>  
  6. </strong>  

二:make_image_mean.cpp函數分析

輸入參數:lmdb文件 均值文件imagenet_mean.binaryproto

2.1 頭文件分析

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. #include<stdint.h>//定義了幾種擴展的整數類型和宏  
  2. #include<algorithm>//輸出數組的內容、對數組進行排序、反轉數組內容、復制數組內容等操作,  
  3. #include<string>  
  4. #include<utility>//utility頭文件定義了一個pair類型,pair類型用於存儲一對數據;它也提供一些常用的便利函數、或類、或模板。大小求值、值交換:min、max和swap。  
  5. #include<vector>//可以自動擴展容量的數組  
  6.   
  7. #include"boost/scoped_ptr.hpp"  
  8. #include"gflags/gflags.h"  
  9. #include"glog/logging.h"  
  10.   
  11. #include"caffe/proto/caffe.pb.h"  
  12. #include"caffe/util/db.hpp"//引入包裝好的lmdb操作函數  
  13. #include"caffe/util/io.hpp"//引入opencv中的圖像操作函數  
  14. usingnamespacecaffe;  //引入caffe命名空間  
  15. usingstd::max;//  
  16. usingstd::pair;  
  17. using boost::scoped_ptr;  

2.2 gflags宏定義string變量

DEFINE_string(backend, "lmdb","The backend {leveldb, lmdb} containing theimages");

2.3 main函數分析

2.3.1 lmdb數據操作

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. scoped_ptr<db::DB>db(db::GetDB(FLAGS_backend));  
  2. db->Open(argv[1], db::READ);//只讀的方式打開lmdb文件  
  3. scoped_ptr<db::Cursor> cursor(db->NewCursor());  
  4. //lmdb數據庫的“光標”文件,一個光標保存一個從數據庫根目錄到數據庫文件的路徑;A cursorholds a path of (page pointer, key index) from the DB root to a position in theDB, plus other state.   
2.3.4 聲明中轉對象變量

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. BlobProtosum_blob;//聲明blob變量;這個BlobProto在哪里定義的,沒有找到;感覺應該在caffe.pb.h中定義的,因為db.cpp和io.cpp中沒有找到  
  2. int count = 0;  
  3. // load first datum  
  4.   Datum datum;  
  5. datum.ParseFromString(cursor->value());//這個cursor.value,感覺返回的應該是lmdb中存儲的第一個鍵值對數據  
2.3.5 給BlobProto類型變量賦值

 

每個blob對象,為一個4維的數組,分別為image_num*channels*height*width

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. sum_blob.set_num(1);//設置圖片的個數  
  2. sum_blob.set_channels(datum.channels());  
  3. sum_blob.set_height(datum.height());  
  4. sum_blob.set_width(datum.width());  
  5. constintdata_size = datum.channels() *datum.height() * datum.width();//每張圖片的尺寸  
  6. intsize_in_datum = std::max<int>(datum.data().size(),datum.float_data_size());  
這個size()和float_data_size()有些不明白,圖像數據正常應該是整形的數據(例如uint8_t),感覺這個size()應該對應的是整型數據的個數,例如一個50*50的彩色圖片,最后應該是50*50*3=750個整型數來表示一幅50*50的圖片;至於這個float_data_size()就不清楚了,感覺是某些圖片數據使用float類型存儲的,所以用float來統計數值的個數。開始感覺這個float的size應該是把int類型轉換成float后,查看在float類型下的字節占用情況;但是由下面的代碼來看,感覺這個size(),統計的是數據的個數也就是750,而不是占用的字節數。如果圖像使用int類型存儲的,那么float_data_size()=0;如果使用float類型存儲的,那么datum.data.size=0。所以每次都要max操作

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. for (inti= 0; i<size_in_datum; ++i) {  
  2. sum_blob.add_data(0.);//設置初值為float型的0.0  
  3.  }  
2.3.6利用循環和cursor讀取lmdb中的數據

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. while (cursor->valid()) {//如果cursor是有效的  
  2.     Datum datum;  
  3. datum.ParseFromString(cursor->value());//解析cuisor.value返回的字符串值,到datum  
  4. DecodeDatumNative(&datum);//感覺是把datum中字符串類型的值,變成相應的類型  
  5. conststd::string& data =datum.data();//利用data來引用datum.data  
  6. size_in_datum = std::max<int>(datum.data().size(),datum.float_data_size());  
  7.     CHECK_EQ(size_in_datum,data_size) <<"Incorrect data field size"<<size_in_datum;  
  8. if (data.size() != 0) {//datum.data().size()!=0  
  9.       CHECK_EQ(data.size(),size_in_datum);//判斷是否相等  
  10. for (inti= 0; i<size_in_datum; ++i) {  
  11. sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);//對應位置的像素值相加(uin8_t類型相加),相加的結果放在sum_blob中  
  12.       }  
  13.     } else{  
  14.      CHECK_EQ(datum.float_data_size(), size_in_datum);  
  15. for (inti= 0; i<size_in_datum; ++i) {  
  16. sum_blob.set_data(i, sum_blob.data(i) +  
  17. static_cast<float>(datum.float_data(i)));//對應位置的像素值相加(float類型相加)  
  18.       }  
  19.     }  
  20.     ++count;  
  21. if (count % 10000 == 0) {  
  22. LOG(INFO) <<"Processed "<<count <<" files.";  
  23.     }  
  24.     cursor->Next();//光標下移(指針),指向下一個存儲在lmdb中的數據  
  25.   }  
2.3.7 求均值

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. for (inti= 0; i<sum_blob.data_size(); ++i) {  
  2. sum_blob.set_data(i, sum_blob.data(i) / count);  
  3.   }  
2.3.8 存儲到指定文件

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. // Write to disk  
  2. if (argc == 3) {  
  3. LOG(INFO) <<"Write to "<<argv[2];  
  4. WriteProtoToBinaryFile(sum_blob, argv[2]);  
  5.   }  
2.3.9 計算每個channel的均值,這個貌似沒有用到吧!

 

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. constint channels = sum_blob.channels();  
  2. constint dim = sum_blob.height() *sum_blob.width();  
  3. std::vector<float>mean_values(channels,0.0);//容量為3的數組,初始值為0.0  
  4. LOG(INFO) <<"Number of channels:"<< channels;  
  5. for (intc = 0; c < channels; ++c) {  
  6. for (inti= 0; i< dim; ++i) {  
  7. mean_values[c] += sum_blob.data(dim * c + i);  
  8.     }  
  9. LOG(INFO) <<"mean_value channel["<< c <<"]:"<<mean_values[c]/ dim;  
  10.   }  
三,相關文件

 

compute_image_mean.cpp

 

[cpp]  view plain copy 在CODE上查看代碼片 派生到我的代碼片
 
  1. #include <stdint.h>  
  2. #include <algorithm>  
  3. #include <string>  
  4. #include <utility>  
  5. #include <vector>  
  6.   
  7. #include "boost/scoped_ptr.hpp"  
  8. #include "gflags/gflags.h"  
  9. #include "glog/logging.h"  
  10.   
  11. #include "caffe/proto/caffe.pb.h"  
  12. #include "caffe/util/db.hpp"  
  13. #include "caffe/util/io.hpp"  
  14.   
  15. using namespace caffe;  // NOLINT(build/namespaces)  
  16.   
  17. using std::max;  
  18. using std::pair;  
  19. using boost::scoped_ptr;  
  20.   
  21. DEFINE_string(backend, "lmdb",  
  22.         "The backend {leveldb, lmdb} containing the images");  
  23.   
  24. int main(int argc, char** argv) {  
  25.   ::google::InitGoogleLogging(argv[0]);  
  26.   
  27. #ifndef GFLAGS_GFLAGS_H_  
  28.   namespace gflags = google;  
  29. #endif  
  30.   
  31.   gflags::SetUsageMessage("Compute the mean_image of a set of images given by"  
  32.         " a leveldb/lmdb\n"  
  33.         "Usage:\n"  
  34.         "    compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");  
  35.   
  36.   gflags::ParseCommandLineFlags(&argc, &argv, true);  
  37.   
  38.   if (argc < 2 || argc > 3) {  
  39.     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");  
  40.     return 1;  
  41.   }  
  42.   
  43.   scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));  
  44.   db->Open(argv[1], db::READ);  
  45.   scoped_ptr<db::Cursor> cursor(db->NewCursor());  
  46.   
  47.   BlobProto sum_blob;  
  48.   int count = 0;  
  49.   // load first datum  
  50.   Datum datum;  
  51.   datum.ParseFromString(cursor->value());  
  52.   
  53.   if (DecodeDatumNative(&datum)) {  
  54.     LOG(INFO) << "Decoding Datum";  
  55.   }  
  56.   
  57.   sum_blob.set_num(1);  
  58.   sum_blob.set_channels(datum.channels());  
  59.   sum_blob.set_height(datum.height());  
  60.   sum_blob.set_width(datum.width());  
  61.   const int data_size = datum.channels() * datum.height() * datum.width();  
  62.   int size_in_datum = std::max<int>(datum.data().size(),datum.float_data_size());  
  63.   for (int i = 0; i < size_in_datum; ++i) {  
  64.     sum_blob.add_data(0.);//設置初值為float型的0.0  
  65.   }  
  66.   LOG(INFO) << "Starting Iteration";  
  67.   while (cursor->valid()) {//如果cursor是有效的  
  68.     Datum datum;  
  69.     datum.ParseFromString(cursor->value());//解析cuisor.value返回的字符串值,到datum  
  70.     DecodeDatumNative(&datum);  
  71.   
  72.     const std::string& data = datum.data();//利用data來引用datum.data  
  73.     size_in_datum = std::max<int>(datum.data().size(),datum.float_data_size());  
  74.     CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<size_in_datum;  
  75.     if (data.size() != 0) {  
  76.       CHECK_EQ(data.size(), size_in_datum);  
  77.       for (int i = 0; i < size_in_datum; ++i) {  
  78.         sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);  
  79.       }  
  80.     } else {  
  81.       CHECK_EQ(datum.float_data_size(), size_in_datum);  
  82.       for (int i = 0; i < size_in_datum; ++i) {  
  83.         sum_blob.set_data(i, sum_blob.data(i) +  
  84.             static_cast<float>(datum.float_data(i)));  
  85.       }  
  86.     }  
  87.     ++count;  
  88.     if (count % 10000 == 0) {  
  89.       LOG(INFO) << "Processed " << count << " files.";  
  90.     }  
  91.     cursor->Next();  
  92.   }  
  93.   
  94.   if (count % 10000 != 0) {  
  95.     LOG(INFO) << "Processed " << count << " files.";  
  96.   }  
  97.   for (int i = 0; i < sum_blob.data_size(); ++i) {  
  98.     sum_blob.set_data(i, sum_blob.data(i) / count);  
  99.   }  
  100.   // Write to disk  
  101.   if (argc == 3) {  
  102.     LOG(INFO) << "Write to " << argv[2];  
  103.     WriteProtoToBinaryFile(sum_blob, argv[2]);  
  104.   }  
  105.   const int channels = sum_blob.channels();  
  106.   const int dim = sum_blob.height() * sum_blob.width();  
  107.   std::vector<float> mean_values(channels, 0.0);  
  108.   LOG(INFO) << "Number of channels: " << channels;  
  109.   for (int c = 0; c < channels; ++c) {  
  110.     for (int i = 0; i < dim; ++i) {  
  111.       mean_values[c] += sum_blob.data(dim * c + i);  
  112.     }  
  113.     LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim;  
  114.   }  
  115.   return 0;  
  116. }  
四:以上代碼注釋為個人理解,如有遺漏,錯誤還望大家多多交流,指正,以便共同學習,進步!!
轉載請標明出處:http://blog.csdn.net/whiteinblue/article/details/45540301


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM