Pytorch系列:(一)常用基礎操作


各種張量初始化

創建特殊類型的tensor

a = torch.FloatTensor(2,3) 
a = torch.DoubleTensor(2,3)
...  

設置pytorch中tensor的默認類型

torch.set_default_tensor_type(torch.DoubleTensor)

更改tensor類型

a.float()

各種常用初始化

torch.randn_like()

torch.rand(3,3)   #創建 0-1 (3,3)矩陣

torch.randn(3,3)  #創建 -1-1 (3,3)矩陣

torch.randint(1,10,[2,2])  #創建 1-10 (2,2) int型矩陣

按照不同的均值和方差進行初始化

torch.normal(mean=torch.full([20],0),std=torch.arange(0,1,0.1))

按照間隔初始化

torch.linspace(0,10,step=3)

torch.arange(1,10,5)

創建單位矩陣

torch.eye(4,4)

創建打亂的數列

torch.randperm(10)

返回tensor元素個數

torch.numel(torch.rand(2,2))

維度操作

矩陣拼接

torch.cat((x,x),0)
torch.stack((x,x),0)   #與cat不同的是,stack在拼接的時候,要增加一個維度

矩陣拆分

chuck直接按照數量來拆分,輸入N就拆分成N個

torch.chunk(a,N,dim) 

split的兩種用法,第一種是輸入一個數字,這樣就會拆分成這個總維度/數字個維度,第二個是如輸入一個列表,會按照列表指定的維度進行拆分

torch.split(a,[1,2],dim)

矩陣選取

在某個維度上選擇連續的N 列或者行

torch.narrow(dim,index,size)

選擇一個維度dim,從index開始取size個列或者行

a.index_select(dim, list)

各種選取

a[ : , 1:10,  ::2 , 1:10:2]

矩陣打平后選取

torch.take( tensor , list)

維度變化

a.view(1,5)
a.reshape(1,5) 

維度減少和增加

只有一個維度的時候,就是0在前面插入,-1或1在后面插入,可以把list當成是0.5維度

a.unsqueeze(1)
a.squeeze(1) 

維度擴張

a.expand()  

維度擴展expand,注意這里的維度只能由1擴張成N,其他情況下是不能擴張的,另外維度不變的時候也可以用-1代替

a.repead()  

另外一種方式是使用repeat函數,repeat表示將之前的維度復制多少次,通過復制來進行擴張

維度交換

transpose(2,3)  # 交換兩個維度
permute(4,2,1,3) # 交換多個維度

數學運算

基礎運算

其中加減除法都可以使用運算符直接計算,乘法需要額外注意兩種不同的乘法,其中:

mul或者*是矩陣對應元素相乘

mm是針對於二維的矩陣正常乘法

matmul是針對任意維度矩陣的正常乘法,@是其符號重載

數字近似

floor() 向下取整

ceil() 向上取整

trunc() 保留整數

frac() 保留小數

數值裁剪

clamp(min)

clamp(min,max) #在這個閾值之外的都變成閾值

累乘

prod()

線性代數相關

trace           #矩陣的跡

diag            #獲取主對角線元素

triu/tril       #獲取上下三角矩陣

t               #轉置

dot/cross       #內積與外積

其他

Numpy Tensor 互相轉換

np_data = np.arange(6).reshape((2, 3))
torch_data = torch.from_numpy(np_data)
tensor2array = torch_data.numpy()

類型判斷

isinstance(a,torch.FloatTensor)

廣播

什么時候可以使用廣播,廣播將從最后一個維度開始,從后往前開始匹配,當一個對象的維度是1或者與另一個對象的維度大小一樣的時候,可以匹配上,另外,如果一個對象的維度少於另外一個維度的對象,只要從后往前開始的維度匹配,那么就可以使用廣播。

例如

(1,2,3,4) 和 (2,3,4) or (1,2,3,4) 可以廣播

(1,2,3,4) 和 (1,1,1) or (1,1,1,1) 可以廣播

topk

topk可以幫助返回在某一維度上最大的k個值以及下標,只需要將largest=False,就可以返回最小的k個值

where條件選擇

根據條件是否成立,選擇矩陣X或者矩陣Y中的元素

where(condition > 0.5 , X , Y )  

gather

本質就是在查表,第一個參數是表格,第二個是維度,第三個是要查詢的索引

操作就是,在inpu中選擇維度dim,然后根據index編號,讀取input中的元素

torch.gather(input,dim,index,out=None) 
 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM