PyTorch ImageFolder自定義數據集


PyTorch自定義數據集,我們介紹了如何通過重寫Dataset類來自定義數據集,但其實對於圖像數據,自定義數據集有一個更簡單的方法,那就是直接調用ImageFolder,它是torchvision.datasets里的函數。

ImageFolder介紹

ImageFolder假設所有的文件按文件夾保存,每個文件夾下存儲同一個類別的圖片,文件夾名為類名,其構造函數如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

各參數含義:

root:在root指定的路徑下尋找圖片

transform:對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回對象

target_transform:對label的轉換

loader:給定路徑后如何讀取圖片,默認讀取為RGB格式的PIL Image對象

label是按照文件夾名順序排序后存成字典,即{類名:類序號(從0開始)}

示例

從kaggle官網下載dogsVScats的數據集(百度網盤的下載鏈接見文末),該數據集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:

cat.0.jpg
cat.1.jpg
cat.2.jpg
...
cat.12499.jpg
dog.0.jpg
dog.1.jpg
dog.2.jpg
...
dog.12499.jpg

假設我們希望把train文件夾中90%貓的圖片和90%狗的圖片作為訓練集,剩下的10%作為驗證集:

import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms, datasets

# kaggle原始數據集在本地電腦的文件路徑
original_dataset_dir = '/Users/wangpeng/Desktop/all/CS/Datasets/kaggle_dogs_cats/train'
total_num = int(len(os.listdir(original_dataset_dir)) / 2)
random_idx = np.array(range(total_num))
np.random.shuffle(random_idx)

# 待處理的數據集地址
base_dir = '/Users/wangpeng/Desktop/dogsVScats'
if not os.path.exists(base_dir):
    os.mkdir(base_dir)

# 訓練集、驗證集的划分
sub_dirs = ['train', 'validate']
animals = ['cats', 'dogs']
train_idx = random_idx[:int(total_num * 0.9)]
validate_idx = random_idx[int(total_num * 0.9):]
numbers = [train_idx, validate_idx]
for idx, sub_dir in enumerate(sub_dirs):
    dir = os.path.join(base_dir, sub_dir)
    if not os.path.exists(dir):
        os.mkdir(dir)
    for animal in animals:
        animal_dir = os.path.join(dir, animal)
        if not os.path.exists(animal_dir):
            os.mkdir(animal_dir)
        fnames = [animal[:-1] + '.{}.jpg'.format(i) for i in numbers[idx]]
        for fname in fnames:
            src = os.path.join(original_dataset_dir, fname)
            dst = os.path.join(animal_dir, fname)
            shutil.copyfile(src, dst)

        # 訓練集、驗證集的圖片數目
        print(animal_dir + ' total images : %d' % (len(os.listdir(animal_dir))))

運行上面的程序,在我的電腦的桌面上將會有一個dogsVScats文件夾,其文件結構如下:

dogsVScats
    |
    |----train
    |         |     
    |         |---cats(包含11250張貓的圖片)
    |         |---dogs(包含11250張狗的圖片)
    |   
    |-----validate
              |
              |---cats(包含1250張貓的圖片)
              |---dogs(包含1250張狗的圖片)

接着我們就可以用ImageFolder創建數據集了,並把創建好的數據集放到DataLoader中:

data_transform = transforms.Compose([
    transforms.Resize(256),         # 把圖片resize為256*256
    transforms.CenterCrop(224),     # 隨機裁剪224*224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 標准化
])

train_dataset = datasets.ImageFolder(root='/Users/wangpeng/Desktop/dogsVScats/train', transform=data_transform)  # 標簽為{'cats':0, 'dogs':1}
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

validate_dataset = datasets.ImageFolder(root='/Users/wangpeng/Desktop/dogsVScats/validate', transform=data_transform)  
validate_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

我們可以測試一下,看一下train_loader可不可以用:

if __name__ == '__main__':
    image, label = iter(train_loader).next()  # iter()函數把train_loader變為迭代器,然后調用迭代器的next()方法
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    sample = np.clip(sample, 0, 1)
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))

運行結果:

Label is: 1

同樣的我們可以測試validate_loader,這里就不再贅述了。

 

dogsVScats數據下載鏈接:鏈接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow  提取密碼:478x

參考文獻

[1] pytorch之ImageFolder使用詳解

[2] pytorch實現kaggle貓狗識別

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM