torch.linspace,unsqueeze()以及squeeze()函數


1.torch.linspace(start,end,steps=100,dtype)

作用是返回一個一維的tensor(張量),其中dtype是返回的數據類型。

import torch
print(torch.linspace(-1,1,5))

輸出結果為:tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])

2.unsqueeze()函數

在指定位置增加維度。

import torch
a=torch.arange(0,6)  #a是一維向量
b=a.reshape(2,3)     #b是二維向量
c=b.unsqueeze(1)     #c是三維向量,在b的第二維上增加一個維度
print(a)
print(b)    
print(c)
print(c.size())

a的維度為1x6

b的維度為2x3

b的維度為2x1x3

若想在倒數第二個維度增加一個維度,則c=b.unsqueeze(-1)

3.squeeze()函數

可去掉維度為1的維度。

import torch
a=torch.arange(0,6)  #a是一維向量
b=a.reshape(2,3)
c=b.unsqueeze(1)
print(c)
print(c.size())
d=c.squeeze(1)
print(d)
print(d.size())

輸出結果為:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM