Pytorch中的tensor又包括CPU上的數據類型和GPU上的數據類型,一般GPU上的Tensor是CPU上的Tensor加cuda()函數得到。
一般系統默認是torch.FloatTensor類型。例如data = torch.Tensor(2,3)是一個2*3的張量,類型為FloatTensor;
data.cuda()就轉換為GPU的張量類型,torch.cuda.FloatTensor類型。
if cuda: dtype = torch.cuda.FloatTensor else: if torch.cuda.is_available(): print("WARNING: You have a CUDA device, so you should probably set cuda=True") dtype = torch.FloatTensor