device


Ref:CSDN

device = torch.device('cuda' if (self.worker == 'gpu' and torch.cuda.is_available()) else 'cpu')

if torch.cuda.device_count() > 1:  # 多gpu
    model = torch.nn.DataParallel(model, device_ids=[x for x in range(self.config.gpu_num)])

几个需要添加to.device的地方

  1. model(如:model.to(device))
  2. input(通常需要使用Variable包装,如:input = Variable(input).to(device))
  3. target(通常需要使用Variable包装,如:target = Variable(torch.from_numpy(np.array(target)).long()).to(device)
  4. nn.CrossEntropyLoss()(如:criterion = nn.CrossEntropyLoss().to(device))


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM