pytorch構建自己的數據集


現在需要在json文件里面讀取圖片的URL和label,這里面可能會出現某些URL地址無效的情況。

python讀取json文件

此處只需要將json文件里面的內容讀取出來就可以了

with open("json_path",'r') ad load_f:
    load_dict = json.load(load_f)

json_path是json文件的地址,json文件里面的內容讀取到load_dict變量中,變量類型為字典類型。

python通過URL打開圖片

通過skimage獲取URL圖片是簡單的方式。

from skimage import io
image = io.imread(img_src) # img_src是圖片的URL
io.imshow(image)
io.show()

pytorch構建自己的數據集

pytorch中文網中有比較好的講解: https://ptorch.com/news/215.html

加載圖片預處理以及可視化見: https://oldpan.me/archives/pytorch-transforms-opencv-scikit-image

定義自己的數據集使用類 torch.utils.data.Dataset這個類,這個類中有三個關鍵的默認成員函數,__init__,__len__,__getitem__。

__init__類實例化應用,所以參數項里面最好有數據集的path,或者是數據以及標簽保存的json、csv文件,在__init__函數里面對json、csv文件進行解析。

__len__需要返回images的數量。

__getitem__中要返回image和相對應的label,要注意的是此處參數有一個index,指的返回的是哪個image和label。

 

import torch
from torchvision import transforms 
import json
import os
from PIL import Image


class ProductDataset(torch.utils.data.Dataset):
    def __init__(self,json_path,data_path,transform = None,train = True):
        with open(json_path,'r') as load_f:
            self.json_dict = json.load(load_f)
        self.json_dict = self.json_dict["images"]
        self.train = train
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        return len(self.json_dict)

    def __getitem__(self,index):
        image_id = os.path.join(self.data_path + '/',str(self.json_dict[index]["id"]))
        image = Image.open(image_id)
        image = image.convert('RGB')
        label = int(self.json_dict[index]["class"])
        if self.transform:
            image = self.transform(image)
        if self.train:
            return image,label
        else:
            image_id = self.json_dict[index]["id"]
            return image,label,image_id


if __name__ == '__main__':
    val_dataset = ProductDataset('data/FullImageTrain.json','data/train',train=False,
                                transform=transforms.Compose([
                                    transforms.Pad(4),
                                    transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                ]))
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                batch_size=32,
                                                shuffle=False,
                                                **kwargs)

    print(val_dataset.__len__())
    count = 0
    for image,label,image_id in test_loader:
        print(image.shape,count)
        count += 1

 

關於transform,圖像預處理的各個函數功能介紹如下:

torch.transforms是常見的圖像變換,可以用Compose連接起來。

下面是Transforms on PIL Image:

transforms.CenterCrop(size):

size可以是一個像(h,w)的sequence,這樣輸出的是一個中心裁剪的(h,w)圖像。

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):

隨機更改圖像的亮度,對比度和飽和度。

傳遞的參數是float型變量或者是tuple(元素是float型)型變量,如果是tuple型變量,第一個元素是min值,第二個元素是max值。

transforms.Grayscale(num_output_channels=1)

將Image轉換為灰度值

transforms.Pad(padding, fill=0, padding_mode='constant')

padding這個參數,如果給定的是單個的值,那么會pad所有的邊。

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

隨機裁剪圖片到給定尺寸

size如果是(h,w)這樣的sequence,那么將剪出一個(h,w)大小的圖片

transforms.RandomHorizontalFlip(p=0.5):

以給定的概率隨機水平翻轉給定的PIL圖像。

transforms.RandomResizedCrop(size,scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

將給定的圖像隨機裁剪為不同的大小和高寬比,然后縮放所裁剪的圖像到指定大小。

 

該操作的含義:即使只是該物體的一部分,我們也認為這是該類物體。

scale為0.08到1的意思為裁剪的面積比例為0.08到1,注意是面積不是邊,ratio是高寬比。
transforms.Resize(size, interpolation=2):

Resize給定的Image圖像到指定大小。

size:給定圖像大小

interpolation:差值方法,默認是PIL.Image.BILINEAR

下面是Transforms on torch.*Tensor:

transforms.Normalize(mean,var,inplace=False):

標准化圖像,mean和var給定三個值的情況下,是分別對於RGB三個channel進行標准化。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM