title: libtorch教程(四)
date: 2021-01-18 19:50:16
tags: libtorch
本章將詳細介紹如何使用libtorch自帶的數據加載模塊,使用該模塊是實現模型訓練的重要條件。除非這個數據加載模塊功能不夠,不然繼承libtorch的數據加載類還是很有必要的,簡單高效。
使用前置條件
libtorch提供了豐富的基類供用戶自定義派生類,torch::data::Dataset就是其中一個常用基類。使用該類需要明白基類和派生類,以及所謂的繼承和多態。有c++編程經驗者應該都不會陌生,為方便不同階段讀者就簡單解釋一下吧。類就是父親,可以生出不同的兒子,生兒子叫派生或者繼承(看使用語境),生不同的兒子就實現了多態。父親就是基類,兒子就是派生類。現實中,父親會把自身的一部分財產留下來養老,兒子們都不能碰,這就是private了,部分財產兒子能用,但是兒子的對象不能用,這叫protected,還有些財產誰都能用就是public。和現實中的父子類似,代碼中,派生類可以使用父類的部分屬性或者函數,全看父類怎樣定義。
然后理解一下虛函數,就是父親指定了部分財產是public的,但是是用來買房的,不同的兒子可以買不同的房子,可以全款可以貸款,這就是財產在父親那就是virtual的。子類要繼承這個virtual財產可以自己重新規划使用方式。
事實上,如果有過pytorch的編程經驗者很快會發現,libtorch的Dataset類的使用和python下使用非常相像。pytorch自定義dataload,需要定義好Dataset的派生類,包括初始化函數__init__,獲取函數__getitem__以及數據集大小函數__len__。類似的,libtorch中同樣需要處理好初始化函數,get()函數和size()函數。
圖片文件遍歷
下面以分類任務為例,介紹libtorch的Dataset類的使用。使用pytorch官網提供的昆蟲分類數據集,下載到本地解壓。將該數據集根目錄作為索引,實現Dataloader對圖片的加載。
首先定義一個加載圖片的函數,使用網上出現較多的c++遍歷文件夾的代碼,將代碼稍作修改如下:
//遍歷該目錄下的.jpg圖片
void load_data_from_folder(std::string image_dir, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
{
long long hFile = 0; //句柄
struct _finddata_t fileInfo;
std::string pathName;
if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
{
return;
}
do
{
const char* s = fileInfo.name;
const char* t = type.data();
if (fileInfo.attrib&_A_SUBDIR) //是子文件夾
{
//遍歷子文件夾中的文件(夾)
if (strcmp(s, ".") == 0 || strcmp(s, "..") == 0) //子文件夾目錄是.或者..
continue;
std::string sub_path = path + "\\" + fileInfo.name;
label++;
load_data_from_folder(sub_path, type, list_images, list_labels, label);
}
else //判斷是不是后綴為type文件
{
if (strstr(s, t))
{
std::string image_path = path + "\\" + fileInfo.name;
list_images.push_back(image_path);
list_labels.push_back(label);
}
}
} while (_findnext(hFile, &fileInfo) == 0);
return;
}
修改后的函數接受數據集文件夾路徑image_dir和圖片類型image_type,將遍歷到的圖片路徑和其類別分別存儲到list_images和list_labels,最后lable變量用於表示類別計數。傳入lable=-1,返回的lable值加一后等於圖片類別。
自定義Dataset
定義dataSetClc,該類繼承自torch::data::Dataset。定義私有變量image_paths和labels分別存儲圖片路徑和類別,是兩個vector變量。dataSetClc的初始化函數就是加載圖片和類別。通過get()函數返回由圖像和類別構成的張量列表。可以在get()函數中做任意針對圖像的操作,如數據增強等。效果等價於pytorch中的__getitem__中的數據增強。
class dataSetClc:public torch::data::Dataset<dataSetClc>{
public:
int class_index = 0;
dataSetClc(std::string image_dir, std::string type){
load_data_from_folder(image_dir, std::string(type), image_paths, labels, class_index-1);
}
// Override get() function to return tensor at location index
torch::data::Example<> get(size_t index) override{
std::string image_path = image_paths.at(index);
cv::Mat image = cv::imread(image_path);
cv::resize(image, image, cv::Size(224, 224)); //尺寸統一,用於張量stack,否則不能使用stack
int label = labels.at(index);
torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }); // Channels x Height x Width
torch::Tensor label_tensor = torch::full({ 1 }, label);
return {img_tensor.clone(), label_tensor.clone()};
}
// Override size() function, return the length of data
torch::optional<size_t> size() const override {
return image_paths.size();
};
private:
std::vector<std::string> image_paths;
std::vector<int> labels;
};
使用自定義的Dataset
下面使用定義好的數據加載類,以昆蟲分類中的訓練集作為測試,代碼如下。可以打印加載的圖片張量和類別。
int batch_size = 2;
std::string image_dir = "your path to\\hymenoptera_data\\train";
auto mdataset = myDataset(image_dir,".jpg").map(torch::data::transforms::Stack<>());
auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(mdataset), batch_size);
for(auto &batch: *mdataloader){
auto data = batch.data;
auto target = batch.target;
std::cout<<data.sizes()<<target;
}
分享不易,如果有用請不吝給我一個👍,轉載注明出處:https://allentdan.github.io/
代碼見LibtorchTutorials