Pytorch數據讀取框架


訓練一個模型需要有一個數據庫,一個網絡,一個優化函數。數據讀取是訓練的第一步,以下是pytorch數據輸入框架。

1)實例化一個數據庫

假設我們已經定義了一個FaceLandmarksDataset數據庫,此數據庫將在以下建立。

import FaceLandmarksDataset
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/',
                                    transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor()]) )

 

或者使用torchvision.datasets里封裝的數據集(MNIST、Fashion-MNIST、KMNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imagenet-12、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes)

import torchvision.datasets
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')

2)創建一個數據加載器

import torch.utils.data.DataLoader
imagenet_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,  
                                          shuffle=True,
                                          num_workers=4)
#or

facelandmark_loader = torch.utils.data.DataLoader(face_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4) 

可見,數據加載器是通用的,只有數據庫實例不一樣,其它的都參數都一樣,參數值可以根據任務需要自己調。

3)使用數據庫

數據加載器可迭代的,我們可以使用數據庫:

for item in facelandmark_loader:
     images,labels = item
do_somethi

當然, 我們也可以直接對數據庫實例face_dataset進行下標操作,但這樣只能夠每次獲取一條數據。

sample = face_dataset[index]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM