Pytorch的tensor數據類型


基本類型

torch.Tensor是一種包含單一數據類型元素的多維矩陣。

Torch定義了七種CPU tensor類型和八種GPU tensor類型:

Data tyoe CPU tensor GPU tensor
32-bit floating point torch.FloatTensor torch.cuda.FloatTensor
64-bit floating point torch.DoubleTensor torch.cuda.DoubleTensor
16-bit floating point N/A torch.cuda.HalfTensor
8-bit integer (unsigned) torch.ByteTensor torch.cuda.ByteTensor
8-bit integer (signed) torch.CharTensor torch.cuda.CharTensor
16-bit integer (signed) torch.ShortTensor torch.cuda.ShortTensor
32-bit integer (signed) torch.IntTensor torch.cuda.IntTensor
64-bit integer (signed) torch.LongTensor torch.cuda.LongTensor

torch.DoubleTensor(2, 2) 構建一個22 Double類型的張量
torch.ByteTensor(2, 2) 構建一個2
2 Byte類型的張量
torch.CharTensor(2, 2) 構建一個22 Char類型的張量
torch.ShortTensor(2, 2) 構建一個2
2 Short類型的張量
torch.IntTensor(2, 2) 構建一個22 Int類型的張量
torch.LongTensor(2, 2) 構建一個2
2 Long類型的張量

類型轉換

2.1 CPU和GPU的Tensor之間轉換

從cpu –> gpu,使用data.cuda()即可。
若從gpu –> cpu,則使用data.cpu()。

2.2 Tensor與Numpy Array之間的轉換

Tensor –> Numpy.ndarray 可以使用 data.numpy(),其中data的類型為torch.Tensor。
Numpy.ndarray –> Tensor 可以使用torch.from_numpy(data),其中data的類型為numpy.ndarray。

2.3 Tensor的基本類型轉換(也就是float轉double,轉byte這種。)

為了方便測試,我們構建一個新的張量,你要轉變成不同的類型只需要根據自己的需求選擇即可

  1. tensor = torch.Tensor(2, 5)

  2. torch.long() 將tensor投射為long類型
    newtensor = tensor.long()

  3. torch.half()將tensor投射為半精度浮點(16位浮點)類型
    newtensor = tensor.half()

  4. torch.int()將該tensor投射為int類型
    newtensor = tensor.int()

  5. torch.double()將該tensor投射為double類型
    newtensor = tensor.double()

  6. torch.float()將該tensor投射為float類型
    newtensor = tensor.float()

  7. torch.char()將該tensor投射為char類型
    newtensor = tensor.char()

  8. torch.byte()將該tensor投射為byte類型
    newtensor = tensor.byte()

  9. torch.short()將該tensor投射為short類型
    newtensor = tensor.short()

如果當你需要提高精度,比如說想把模型從float變為double。那么可以將要訓練的模型設置為model = model.double()。此外,還要對所有的張量進行設置:pytorch.set_default_tensor_type('torch.DoubleTensor'),不過double比float要慢很多,要結合實際情況進行思考。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM