一、PyTorch 入門實戰—Tensor(轉)


一、Tensor的創建和使用

1.概念和TensorFlow的是基本一致的,只是代碼編寫格式的不同。我們聲明一個Tensor,並打印它,例如:

import torch #定義一個Tensor矩陣
a = torch.Tensor([1, 2], [3, 4],[5, 6], [7, 8]) print(a) print('{}'.format(a))

然后會發現報以下錯誤:

new() received an invalid combination of arguments - got (list, list, list, list), but expected one of: * (torch.device device) * (torch.Storage storage) * (Tensor other) * (tuple of ints size, torch.device device) * (object data, torch.device device)                      

意思是接收到無效的參數組合。其實是少寫了一對中括號,這是初學者的常用錯誤。

2.改成如下形式:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(a) print('{}'.format(a))

結果為:

                                               

3.如果想查看的它的大小可以加一句話:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a.size()))

結果為:

                                                       

即4行2列的矩陣

4.如果想生成一個全為0的矩陣,可以輸入如下代碼:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a.size())) b = torch.zeros((4, 2)) print(b)

結果為:

                                                          

即4行2列數組元素全為0的矩陣

5.如果想生成不同類型的數據,可以改變torch.后面函數名稱,例如下面這樣:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a)) b = torch.zeros((4, 2)) print(b) c = torch.IntTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(c) d = torch.LongTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(d) e = torch.DoubleTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(e)

結果為:

                                                  

6.如果想訪問Tensor里的一個元素或者改變它,可以輸入如下代碼:

print(e[1, 1]) #改變元素值
e[1, 1] = 3
print(e[1, 1])

代碼變為: 

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a)) b = torch.zeros((4, 2)) print(b) c = torch.IntTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(c) d = torch.LongTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(d) e = torch.DoubleTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(e) print(e[1, 1]) #改變元素值
e[1, 1] = 3
print(e[1, 1])

結果為:

                                                          

說明原來4的位置數值變為了3

7.最重要的是Tensor和Numpy之間的轉換,例如我們把e變為numpy類型,添加以下代碼:

f = e.numpy() print(f)

變為:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a)) b = torch.zeros((4, 2)) print(b) c = torch.IntTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(c) d = torch.LongTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(d) e = torch.DoubleTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(e) print(e[1, 1]) #改變元素值
e[1, 1] = 3
print(e[1, 1]) #轉換為Numpy
f = e.numpy() print(f)

結果為: 

                                                           

可以看到沒有tensor()了~

我們再把f變為tensor類型,輸入以下代碼:

g = torch.from_numpy(f) print(g)

變為:

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a)) b = torch.zeros((4, 2)) print(b) c = torch.IntTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(c) d = torch.LongTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(d) e = torch.DoubleTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(e) print(e[1, 1]) #改變元素值
e[1, 1] = 3
print(e[1, 1]) #轉換為Numpy
f = e.numpy() print(f) #轉換為Tensor
g = torch.from_numpy(f) print(g)

結果為:

                                                             

可以看到又變成了Tensor類型

二、Tensor放到GPU上執行

1.通過如下代碼判斷是否支持GPU:

if torch.cuda.is_available(): h = g.cuda() print(h)

變為

import torch #定義一個Tensor矩陣
a = torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print('{}'.format(a)) b = torch.zeros((4, 2)) print(b) c = torch.IntTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(c) d = torch.LongTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(d) e = torch.DoubleTensor([[1, 2], [3, 4],[5, 6], [7, 8]]) print(e) print(e[1, 1]) #改變元素值
e[1, 1] = 3
print(e[1, 1]) #轉換為Numpy
f = e.numpy() print(f) #轉換為Tensor
g = torch.from_numpy(f) print(g) #將Tensor放在GPU上
if torch.cuda.is_available(): h = g.cuda() print(h)

2.生成結果會慢一下,然后可以看到多了一個device=‘cuda:0’:

                                             

三、Tensor總結

1.Tensor和Numpy都是矩陣,區別是前者可以在GPU上運行,后者只能在CPU上

2.Tensor和Numpy互相轉化很方便,類型也比較兼容

3.Tensor可以直接通過print顯示數據類型,而Numpy不可以,例如:dtype = torch.float64


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM