Pytorch:Tensor 張量操作


張量操作

一、張量的拼接與切分

1.1 torch.cat()

功能:將張量按維度dim進行拼接

tensors:張量序列

dim:要拼接的維度

1.2 torch.stack()

功能:在新創建的維度的上進行拼接

tensors:張量序列

dim:要拼接的維度(如果dim為新的維度,則新增一個維度進行拼接,新維度只能高一維)

         

       

 

1.3 torch.chunk()

功能:將張量按維度進行平均切分

返回值:張量列表

注意事項:若不能整除,最后一份小於其他張量;整除時令商為向上取整的數,如7/3=2.333,取整為3

input:要切分的張量

chunks:要切分的份數

dim:要切分的維度

 

將張量a在第一維上的數據分成三份

 

運行​​​​結果

 

1.4 torch.split()

功能:將張量按維度進行平均切分

返回值:張量列表

input:要切分的張量

split_size_or_sections:為int時,表示每一份的長度;為list時,按list元素切分(注意list的各元素之和需等於維度上的長度

dim:要切分的維度

 

 

二、張量索引

2.1 torch.index_select()

功能:在維度dim上,按index索引數據

返回值:依index索引數據拼接的張量

input:要索引的張量

dim:要索引的維度

index:要索引數據的序號(注意index的數據類型要為torch.long,float會報錯)

 

2.2 torch.masked_select()

功能:按mask中的True進行索引

返回值:一維張量

input:要索引的張量

mask:與input同形狀的布爾類型張量(mask的生成可以通過比較大小關系得出,le為小於等於,詳見圖英文注釋)

 

 三、張量變換

3.1 torch.reshape()

功能:變換張量形狀

注意事項:當張量在內存中是連續時,新張量與input共享數據內存

input:要變換的張量

size:新張量的形狀(形狀中若有-1,則該處的值有其他維數及總數來計算得出)

 

3.2 torch.transpose()

功能:交換兩個張量的維度

input:要交換的張量

dim0:要交換的維度

dim1:要交換的維度

3.3 torch.t()

功能:2維張量轉置,對矩陣而言,等價於torch.transpose(input,0,1)

 

3.4 torch.squeeze()

功能:壓縮長度為1的維度(軸)

dim:若為None,移除所有長度為1的軸;如果指定維度,當且僅當該軸長度為1時,可以被移除

3.5 torch.unsqueeze()

功能:依據dim擴展維度

dim:擴展的維度

 

三、張量數學運算

主要可分為三類:

1.加減乘除   2. 對數、指數、冪函數  3.三角函數

其中加法比較特殊

torch.add()

功能:逐元素計算該式 input+alpha*other(為了簡便於梯度下降的運算)

input:第一個張量

alpha:乘項因子

other:第二個張量

 

另外的拓展還有


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM