libtorch (pytorch c++) 教程(四)



title: libtorch教程(四)
date: 2021-01-18 19:50:16
tags: libtorch

本章將詳細介紹如何使用libtorch自帶的數據加載模塊,使用該模塊是實現模型訓練的重要條件。除非這個數據加載模塊功能不夠,不然繼承libtorch的數據加載類還是很有必要的,簡單高效。

使用前置條件

libtorch提供了豐富的基類供用戶自定義派生類,torch::data::Dataset就是其中一個常用基類。使用該類需要明白基類和派生類,以及所謂的繼承和多態。有c++編程經驗者應該都不會陌生,為方便不同階段讀者就簡單解釋一下吧。類就是父親,可以生出不同的兒子,生兒子叫派生或者繼承(看使用語境),生不同的兒子就實現了多態。父親就是基類,兒子就是派生類。現實中,父親會把自身的一部分財產留下來養老,兒子們都不能碰,這就是private了,部分財產兒子能用,但是兒子的對象不能用,這叫protected,還有些財產誰都能用就是public。和現實中的父子類似,代碼中,派生類可以使用父類的部分屬性或者函數,全看父類怎樣定義。

然后理解一下虛函數,就是父親指定了部分財產是public的,但是是用來買房的,不同的兒子可以買不同的房子,可以全款可以貸款,這就是財產在父親那就是virtual的。子類要繼承這個virtual財產可以自己重新規划使用方式。

事實上,如果有過pytorch的編程經驗者很快會發現,libtorch的Dataset類的使用和python下使用非常相像。pytorch自定義dataload,需要定義好Dataset的派生類,包括初始化函數__init__,獲取函數__getitem__以及數據集大小函數__len__。類似的,libtorch中同樣需要處理好初始化函數,get()函數和size()函數。

圖片文件遍歷

下面以分類任務為例,介紹libtorch的Dataset類的使用。使用pytorch官網提供的昆蟲分類數據集,下載到本地解壓。將該數據集根目錄作為索引,實現Dataloader對圖片的加載。

首先定義一個加載圖片的函數,使用網上出現較多的c++遍歷文件夾的代碼,將代碼稍作修改如下:

//遍歷該目錄下的.jpg圖片
void load_data_from_folder(std::string image_dir, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);

void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
{
    long long hFile = 0; //句柄
    struct _finddata_t fileInfo;
    std::string pathName;
    if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
    {
        return;
    }
    do
    {
        const char* s = fileInfo.name;
        const char* t = type.data();

        if (fileInfo.attrib&_A_SUBDIR) //是子文件夾
        {
            //遍歷子文件夾中的文件(夾)
            if (strcmp(s, ".") == 0 || strcmp(s, "..") == 0) //子文件夾目錄是.或者..
                continue;
            std::string sub_path = path + "\\" + fileInfo.name;
            label++;
            load_data_from_folder(sub_path, type, list_images, list_labels, label);

        }
        else //判斷是不是后綴為type文件
        {
            if (strstr(s, t))
            {
                std::string image_path = path + "\\" + fileInfo.name;
                list_images.push_back(image_path);
                list_labels.push_back(label);
            }
        }
    } while (_findnext(hFile, &fileInfo) == 0);
    return;
}

修改后的函數接受數據集文件夾路徑image_dir和圖片類型image_type,將遍歷到的圖片路徑和其類別分別存儲到list_images和list_labels,最后lable變量用於表示類別計數。傳入lable=-1,返回的lable值加一后等於圖片類別。

自定義Dataset

定義dataSetClc,該類繼承自torch::data::Dataset。定義私有變量image_paths和labels分別存儲圖片路徑和類別,是兩個vector變量。dataSetClc的初始化函數就是加載圖片和類別。通過get()函數返回由圖像和類別構成的張量列表。可以在get()函數中做任意針對圖像的操作,如數據增強等。效果等價於pytorch中的__getitem__中的數據增強。

class dataSetClc:public torch::data::Dataset<dataSetClc>{
public:
    int class_index = 0;
    dataSetClc(std::string image_dir, std::string type){
        load_data_from_folder(image_dir, std::string(type), image_paths, labels, class_index-1);
    }
    // Override get() function to return tensor at location index
    torch::data::Example<> get(size_t index) override{
        std::string image_path = image_paths.at(index);
        cv::Mat image = cv::imread(image_path);
        cv::resize(image, image, cv::Size(224, 224)); //尺寸統一,用於張量stack,否則不能使用stack
        int label = labels.at(index);
        torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }); // Channels x Height x Width
        torch::Tensor label_tensor = torch::full({ 1 }, label);
        return {img_tensor.clone(), label_tensor.clone()};
    }
    // Override size() function, return the length of data
    torch::optional<size_t> size() const override {
        return image_paths.size();
    };
private:
    std::vector<std::string> image_paths;
    std::vector<int> labels;
};

使用自定義的Dataset

下面使用定義好的數據加載類,以昆蟲分類中的訓練集作為測試,代碼如下。可以打印加載的圖片張量和類別。

int batch_size = 2;
std::string image_dir = "your path to\\hymenoptera_data\\train";
auto mdataset = myDataset(image_dir,".jpg").map(torch::data::transforms::Stack<>());
auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(mdataset), batch_size);
for(auto &batch: *mdataloader){
    auto data = batch.data;
    auto target = batch.target;
    std::cout<<data.sizes()<<target;
}

分享不易,如果有用請不吝給我一個👍,轉載注明出處:https://allentdan.github.io/
代碼見LibtorchTutorials


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM