pytorch常用函數總結
torch.max(input,dim)
求取指定維度上的最大值,,返回輸入張量給定維度上每行的最大值,並同時返回每個最大值的位置索引。
比如:
demo.shape
Out[7]: torch.Size([10, 3, 10, 10])
torch.max(demo,1)[0].shape
Out[8]: torch.Size([10, 10, 10])
torch.max(demo,1)[0]
這其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]
就是返回的最大值對應的位置索引。例子如下:
a
Out[8]:
tensor([[1., 2., 3.],
[4., 5., 6.]])
a.max(1)
Out[9]:
torch.return_types.max(
values=tensor([3., 6.]),
indices=tensor([2, 2]))
class torch.nn.ParameterList(parameters=None)
將submodules
保存在一個list
中。
ParameterList
可以像一般的Python list
一樣被索引
。而且ParameterList
中包含的parameters
已經被正確的注冊,對所有的module method
可見。
參數說明:
- modules (list, optional) – a list of nn.Parameter
例子:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
torch.cat()函數
cat是concatnate的意思:拼接,聯系在一起。
先說cat( )的普通用法
如果我們有兩個tensor是A和B,想把他們拼接在一起,需要如下操作:
C = torch.cat( (A,B),0 ) #按維數0拼接(豎着拼)
C = torch.cat( (A,B),1 ) #按維數1拼接(橫着拼)
相當於將tensor按照指定維度進行拼接,比如A的shape為128*64*32*32
,B的shape為 128*32*64*64
,那么按照 torch.cat( (A,B),1)
拼接的之后的形狀為 128*96*64*64
。
注意:
兩個tensor要想進行拼接,必須保證除了指定拼接的維度以外其他的維度形狀必須相同,比如上面的例子,拼接A和B時,A的形狀為128*64*32*32
,B的形狀為128*32*64*64
,只有第二個維度的維數數值不同,其他的維度的維數都是相同的,所以拼接時可按維度1進行拼接(注意,維度的下標是從0開始的,比如 A 的形狀對應的維度下標為:\(128_0*64_1*32_2*32_3\))
contiguous()函數的使用
contiguous一般與transpose,permute,view搭配使用:使用transpose或permute進行維度變換后,調用contiguous,然后方可使用view對維度進行變形(如:tensor_var.contiguous().view() ),示例如下:
x = torch.Tensor(2,3)
y = x.permute(1,0) # permute:二維tensor的維度變換,此處功能相當於轉置transpose
y.view(-1) # 報錯,view使用前需調用contiguous()函數
y = x.permute(1,0).contiguous()
y.view(-1) # OK
具體原因有兩種說法:
1 transpose、permute等維度變換操作后,tensor在內存中不再是連續存儲的,而view操作要求tensor的內存連續存儲,所以需要contiguous來返回一個contiguous copy;
2 維度變換后的變量是之前變量的淺拷貝,指向同一區域,即view操作會連帶原來的變量一同變形,這是不合法的,所以也會報錯;---- 這個解釋有部分道理,也即contiguous返回了tensor的深拷貝contiguous copy數據;
原文鏈接:https://zhuanlan.zhihu.com/p/64376950
tensor.repeat()函數
該函數傳入的參數個數不少於tensor的維數,其中每個參數代表的是對該維度重復多少次,也就相當於復制的倍數,結合例子更好理解,如下:
>>> import torch
>>>
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>>
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>>
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>>
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>>
>>> # repeat()的參數的個數,不能少於被操作的張量的維度的個數,
>>> # 下面是一些錯誤示例
>>> a.repeat(2).size() # 1D < 2D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>>
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>
參考博客:https://blog.csdn.net/qq_29695701/article/details/89763168
torch.masked_select()函數
a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])
b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)
c = torch.masked_select(a,b)
print(c)
用法:torch.masked_select(x, mask),mask必須轉化成torch.ByteTensor類型。
torch.sort
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
對輸入張量input
沿着指定維按升序排序
。如果不給定dim
,則默認為輸入的最后一維。如果指定參數descending
為True
,則按降序排序
返回元組 (sorted_tensor, sorted_indices) , sorted_indices
為原始輸入中的下標。
參數:
- input (Tensor) – 要對比的張量
- dim (int, optional) – 沿着此維排序
- descending (bool, optional) – 布爾值,控制升降排序
- out (tuple, optional) – 輸出張量。必須為
ByteTensor
或者與第一個參數tensor
相同類型。
例子:
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
-1.6747 0.0610 0.1190 1.4137
-1.4782 0.7159 1.0341 1.3678
-0.3324 -0.0782 0.3518 0.4763
[torch.FloatTensor of size 3x4]
>>> indices
0 1 3 2
2 1 0 3
3 1 0 2
[torch.LongTensor of size 3x4]
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
-1.6747 -0.0782 -1.4782 -0.3324
0.3518 0.0610 0.4763 0.1190
1.0341 0.7159 1.4137 1.3678
[torch.FloatTensor of size 3x4]
>>> indices
0 2 1 2
2 0 2 0
1 1 0 1
[torch.LongTensor of size 3x4]