張量操作
一、張量的拼接與切分
1.1 torch.cat()
功能:將張量按維度dim進行拼接
tensors:張量序列
dim:要拼接的維度
1.2 torch.stack()
功能:在新創建的維度的上進行拼接
tensors:張量序列
dim:要拼接的維度(如果dim為新的維度,則新增一個維度進行拼接,新維度只能高一維)
1.3 torch.chunk()
功能:將張量按維度進行平均切分
返回值:張量列表
注意事項:若不能整除,最后一份小於其他張量;整除時令商為向上取整的數,如7/3=2.333,取整為3
input:要切分的張量
chunks:要切分的份數
dim:要切分的維度


1.4 torch.split()
功能:將張量按維度進行平均切分
返回值:張量列表
input:要切分的張量
split_size_or_sections:為int時,表示每一份的長度;為list時,按list元素切分(注意list的各元素之和需等於維度上的長度)
dim:要切分的維度
二、張量索引
2.1 torch.index_select()
功能:在維度dim上,按index索引數據
返回值:依index索引數據拼接的張量
input:要索引的張量
dim:要索引的維度
index:要索引數據的序號(注意index的數據類型要為torch.long,float會報錯)
2.2 torch.masked_select()
功能:按mask中的True進行索引
返回值:一維張量
input:要索引的張量
mask:與input同形狀的布爾類型張量(mask的生成可以通過比較大小關系得出,le為小於等於,詳見圖英文注釋)
三、張量變換
3.1 torch.reshape()
功能:變換張量形狀
注意事項:當張量在內存中是連續時,新張量與input共享數據內存
input:要變換的張量
size:新張量的形狀(形狀中若有-1,則該處的值有其他維數及總數來計算得出)
3.2 torch.transpose()
功能:交換兩個張量的維度
input:要交換的張量
dim0:要交換的維度
dim1:要交換的維度
3.3 torch.t()
功能:2維張量轉置,對矩陣而言,等價於torch.transpose(input,0,1)
3.4 torch.squeeze()
功能:壓縮長度為1的維度(軸)
dim:若為None,移除所有長度為1的軸;如果指定維度,當且僅當該軸長度為1時,可以被移除
3.5 torch.unsqueeze()
功能:依據dim擴展維度
dim:擴展的維度
三、張量數學運算
主要可分為三類:
1.加減乘除 2. 對數、指數、冪函數 3.三角函數
其中加法比較特殊
torch.add()
功能:逐元素計算該式 input+alpha*other(為了簡便於梯度下降的運算)
input:第一個張量
alpha:乘項因子
other:第二個張量
另外的拓展還有