(十)pytorch多線程訓練,DataLoader的num_works參數設置


一、概述

數據集較小時(小於2W)建議num_works不用管默認就行,因為用了反而比沒用慢。
當數據集較大時建議采用,num_works一般設置為(CPU線程數+-1)為最佳,可以用以下代碼找出最佳num_works(注意windows用戶如果要使用多核多線程必須把訓練放在if __name__ == '__main__':下才不會報錯)

二、代碼

import time
import torch.utils.data as d
import torchvision
import torchvision.transforms as transforms
 
 
if __name__ == '__main__':
    BATCH_SIZE = 100
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST('\mnist', download=False, train=True, transform=transform)
    
    # data loaders
    train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    
    for num_workers in range(20):
        train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
        # training ...
        start = time.time()
        for epoch in range(1):
            for step, (batch_x, batch_y) in enumerate(train_loader):
                pass
        end = time.time()
        print('num_workers is {} and it took {} seconds'.format(num_workers, end - start))

 三、查看線程數

1、cpu個數

grep 'physical id' /proc/cpuinfo | sort -u

2、核心數

grep 'core id' /proc/cpuinfo | sort -u | wc -l

3、線程數

grep 'processor' /proc/cpuinfo | sort -u | wc -l

4、例子

命令執行結果如圖所示,根據結果得知,此服務器有1個cpu,6個核心,每個核心2線程,共12線程。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM