torch.cuda.FloatTensor


Pytorch中的tensor又包括CPU上的数据类型和GPU上的数据类型,一般GPU上的Tensor是CPU上的Tensor加cuda()函数得到。

一般系统默认是torch.FloatTensor类型。例如data = torch.Tensor(2,3)是一个2*3的张量,类型为FloatTensor;

data.cuda()就转换为GPU的张量类型,torch.cuda.FloatTensor类型。

if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM