Class balancing in PyTorch


Use WeightedRandomSampler:

import torch
from torchvision import datasets


def make_weights_for_balanced_classes(images, nclasses):
    # images is a list of (path, class_index) pairs, e.g. ImageFolder.imgs
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight each class by the inverse of its frequency
    # (assumes every class has at least one image)
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # Give every sample the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight


dataset_train = datasets.ImageFolder(traindir)

# For unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# shuffle must not be set together with sampler: DataLoader raises a ValueError
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, sampler=sampler, num_workers=args.workers, pin_memory=True)
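To see why inverse-frequency weights balance the classes, it may help to check the arithmetic on a tiny example. The sketch below is plain Python, not the code from the post: the helper name balanced_sample_weights and the toy labels are made up for illustration. It computes total / count[class] for every sample and then sums the weights per class.

```python
from collections import Counter


def balanced_sample_weights(labels):
    """Return one weight per sample: total samples / samples in that class."""
    counts = Counter(labels)
    total = len(labels)
    return [total / counts[label] for label in labels]


# Hypothetical toy labels: class 0 has 6 samples, class 1 has 3, class 2 has 1.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]
weights = balanced_sample_weights(labels)

# Sum the sample weights inside each class.
weight_per_class = Counter()
for label, weight in zip(labels, weights):
    weight_per_class[label] += weight

print(dict(weight_per_class))
```

Every class ends up with the same total weight (up to floating-point rounding), so a sampler that draws in proportion to these weights should pick each class about equally often, however many samples the class has.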

Reference: Balanced Sampling between classes with torchvision DataLoader

 

Reference method 2: the author illustrates the difference between balanced and imbalanced sampling

imbalanced-dataset-sampler
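The difference between the two sampling modes can be reproduced with a small experiment. The following is a minimal sketch and not code from the imbalanced-dataset-sampler project: the 900/100 label split, the seed, and the helper class_fractions are assumptions chosen for illustration. It draws one epoch of indices with RandomSampler (plain shuffling) and one with WeightedRandomSampler, then compares the class fractions.

```python
import torch
from torch.utils.data import RandomSampler, WeightedRandomSampler

# Made-up imbalanced labels: 900 samples of class 0, 100 samples of class 1.
labels = torch.cat([
    torch.zeros(900, dtype=torch.long),
    torch.ones(100, dtype=torch.long),
])

# Inverse-frequency weight for every sample.
class_counts = torch.bincount(labels)
sample_weights = (1.0 / class_counts.double())[labels]

generator = torch.Generator().manual_seed(0)

# Plain shuffling: every index appears exactly once per epoch.
uniform_idx = list(RandomSampler(labels, generator=generator))

# Weighted sampling with replacement: rare classes are drawn more often.
weighted_idx = list(WeightedRandomSampler(
    sample_weights, num_samples=len(labels), replacement=True, generator=generator))


def class_fractions(indices):
    """Fraction of drawn samples that belong to each class."""
    drawn = labels[torch.tensor(indices)]
    return (torch.bincount(drawn, minlength=2).double() / len(indices)).tolist()


uniform_fractions = class_fractions(uniform_idx)
weighted_fractions = class_fractions(weighted_idx)
print("shuffled:", uniform_fractions)
print("weighted:", weighted_fractions)
```

With plain shuffling the minority class keeps its 10% share of the epoch, while with the weighted sampler it should land near 50%. Because the weighted sampler draws with replacement, minority samples are repeated within an epoch and some majority samples are never seen.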

