PyTorch中的常用的tensor類型
PyTorch中的常用的tensor類型包括:
32位浮點型torch.FloatTensor,
64位浮點型torch.DoubleTensor,
16位整型torch.ShortTensor,
32位整型torch.IntTensor,
64位整型torch.LongTensor。
類型之間的轉換
一般只要在tensor后加long(), int(), double(),float(),byte()等函數就能將tensor進行類型轉換
此外,還可以使用type()函數,data為Tensor數據類型,data.type()為給出data的類型,如果使用data.type(torch.FloatTensor)則強制轉換為torch.FloatTensor類型張量。
a1.type_as(a2)可將a1轉換為a2同類型。
tensor和numpy.array轉換
tensor -> numpy.array: data.numpy(),如:
numpy.array -> tensor: torch.from_numpy(data),如:
CPU張量和GPU張量之間的轉換
CPU -> GPU: data.cuda()
GPU -> CPU: data.cpu()
當需要把一個GPU上的tensor數據(假設叫做output)遷移到CPU上並且轉換為numpy類型時,可以用命令output.detach().cpu().numpy()
參考資料:
[1] Pytorch變量類型轉換