torch.cuda.FloatTensor


Pytorch中的tensor又包括CPU上的數據類型和GPU上的數據類型,一般GPU上的Tensor是CPU上的Tensor加cuda()函數得到。

一般系統默認是torch.FloatTensor類型。例如data = torch.Tensor(2,3)是一個2*3的張量,類型為FloatTensor;

data.cuda()就轉換為GPU的張量類型,torch.cuda.FloatTensor類型。

if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM