pytorch-torch參數使用 1.torch.cat(維度串接) 2. torch.backend.cudnn.benchmark(加速優化計算)


1. torch.cat(data, axis) # data表示輸入的數據, axis表示進行串接的維度

t = Test()
t.num = 50
print(t.num)


a = torch.tensor([[1, 1]])
b = torch.tensor([[2, 2]])
x = []
x.append(a) # 維度是[1, 1, 2]
x.append(b) # 維度是[2, 1, 2]

c = torch.cat(x, 0) # 將維度進行串接
print(c.data.numpy().shape)

2. torch.backend.cudnn.benchmark (進行優化加速) 如果每次輸入都是相同的時候,因為需要搜索計算卷積的最佳方式 ,所以在保證維度不變的情況下,可以持續使用最優的計算方法 

  if opt.preprocess != 'scale_width':  # 如果是規則輸入的話,最后的輸入值數量可能低於一個batch_size 
            torch.backends.cudnn.benckmark = True

3. torch.nn.DataParallel (使用多塊GPU進行網絡的訓練)

  if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  #gpu_id = [0, 1, 2, 3]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM