pytorch-使用torch.utils.data.DataLoader, __iter__, ___getitem__来模拟batch数据的处理过程


使用__iter__,  __getitem__来模拟数据处理部分 

import torch.utils.data
class Model():
    def __init__(self, animal_list):
        self.animal_list = animal_list
    # 根据迭代batch_size进行返回
    def __getitem__(self, index):
        root = {'A': self.animal_list[index], 'B': 1}
        return root

    def __len__(self):
        return len(self.animal_list)

class Animal:
    def __init__(self, animal_list):
        self.animals_name = animal_list
        self.m = Model(self.animals_name)
        self.model = torch.utils.data.DataLoader(
            self.m, # 构造两个self.m的输出结果 
            batch_size=2,
            shuffle=True # idx 是随机值 
        )
    def __iter__(self):
        for i, data in enumerate(self.model):

            yield data


animals = Animal(['dog', 'cat', 'fish'])

for i, animal in enumerate(animals):
    print(animal)

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM