torch常用的函数


1、torch.cat():是将两个张量(tensor)拼接在一起。

C = torch.cat( (A,B),0 )  #按维数0拼接(行数增加)
C = torch.cat( (A,B),1 )  #按维数1拼接(列数增加)

2、tensor.expand_as():把一个tensor变成和函数括号内一样形状的tensor

3、tensor.narrow(dim,index,number):dim-取行/列;index-从索引index开始取;number-取的行数/列数

4、contiguous():在调用view()之前需要先调用contiguous()

5、tensor.contiguous().view():范数的数据和传入的tensor一样,只是形状不同;注意view()返回的tensor和传入的tensor共享内存,意思就是只要修改其中一个,数据都会变;如果出现-1,-1表示让电脑帮助我们计算,看例子:

import torch a = torch.arange(0,20)      #此时a的shape是(1,20)
a.view(4,5).shape       #输出为(4,5)
a.view(-1,5).shape      #输出为(4,5)
a.view(4,-1).shape      #输出为(4,5)

6、tensor.transpose(0,1):将0维与1为交互,即转置

7、Tensor.uniform_(from=-1, to=1):将tensor用从均匀分布中抽样得到的值填充

a = torch.Tensor(2, 3).uniform_(-1, 1)

8、Tensor.fill_(1):用指定的数填充tensor

a = torch.Tensor(2,3).fill_(1)

9、Tensor.new():创建一个新的Tensor,该Tensor的type和device都和原有的Tensor保持一致,且该Tensor中无内容;

inputs = torch.randn(m, n)
new_inputs = inputs.new()
new_inputs = torch.Tensor.new(inputs)

#实际应用(添加噪声),可以对Tensor添加噪声,添加如下代码即可实现:
noise = inputs.data.new(inputs.size()).normal_(0,0.01)

#Tensor.data:从原有计算中分离(复制)出来的一个tensor变量(不安全)

10、Tensor.squeeze() & Tensor.unsqueeze() & view():

 1 a = torch.arange(0, 9)
 2 print(a.shape)  # torch.Size([9])
 3 
 4 # 1、利用view()改变tensor的形状。值得注意的是view不会修改自身的数据,返回的新tensor与源tensor共享内存;同时必须保证前后元素总数一致。
 5 b = a.view(3, 3)
 6 print(b.shape)  # torch.Size([2, 5])
 7 
 8 # 2、unsqueeze()该函数用来增加某个维度。在PyTorch中维度是从0开始的
 9 # 在第一个维度(即维度序号为0)前增加一个维度。同理,可在其他位置添加维度
10 c = b.unsqueeze(0)
11 print(c.shape)  # torch.Size([1, 3, 3])
12 
13 d = c.unsqueeze(2)
14 print(d.shape)  # torch.Size([1, 3, 3])
15 
16 # 3、squeeze():该函数用来减少某个维度。
17 e = d.squeeze(2)
18 print(e.shape)

11、

 

12、

 

13、

14、

15、

 

 

 

 

 

 

 

 


 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM