device


Ref:CSDN

device = torch.device('cuda' if (self.worker == 'gpu' and torch.cuda.is_available()) else 'cpu')

if torch.cuda.device_count() > 1:  # 多gpu
    model = torch.nn.DataParallel(model, device_ids=[x for x in range(self.config.gpu_num)])

幾個需要添加to.device的地方

  1. model(如:model.to(device))
  2. input(通常需要使用Variable包裝,如:input = Variable(input).to(device))
  3. target(通常需要使用Variable包裝,如:target = Variable(torch.from_numpy(np.array(target)).long()).to(device)
  4. nn.CrossEntropyLoss()(如:criterion = nn.CrossEntropyLoss().to(device))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM