pytorch-使用torch.utils.data.DataLoader, __iter__, ___getitem__來模擬batch數據的處理過程


使用__iter__,  __getitem__來模擬數據處理部分 

import torch.utils.data
class Model():
    def __init__(self, animal_list):
        self.animal_list = animal_list
    # 根據迭代batch_size進行返回
    def __getitem__(self, index):
        root = {'A': self.animal_list[index], 'B': 1}
        return root

    def __len__(self):
        return len(self.animal_list)

class Animal:
    def __init__(self, animal_list):
        self.animals_name = animal_list
        self.m = Model(self.animals_name)
        self.model = torch.utils.data.DataLoader(
            self.m, # 構造兩個self.m的輸出結果 
            batch_size=2,
            shuffle=True # idx 是隨機值 
        )
    def __iter__(self):
        for i, data in enumerate(self.model):

            yield data


animals = Animal(['dog', 'cat', 'fish'])

for i, animal in enumerate(animals):
    print(animal)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM