pytorch張量數據索引切片與維度變換操作大全(非常全)


(1-1)pytorch張量數據的索引與切片操作
1、對於張量數據的索引操作主要有以下幾種方式:
a=torch.rand(4,3,28,28):DIM=4的張量數據a
(1)a[:2]:取第一個維度的前2個維度數據(不包括2);
(2)a[:2,:1,:,:]:取第一個維度的前兩個數據,取第2個維度的前1個數據,后兩個維度全都取到;
(3)a[:2,1:,:,:]:取第一個維度的前兩個數據,取第2個維度的第1個索引到最后索引的數據(包含1),后兩個維度全都取到;
(4)a[:2,-3:]:負號表示第2個維度上從倒數第3個數據取到最后倒數第一個數據-1(包含-3);
(5)a[:,:,0:28:2,0:28:2]:兩個冒號表示隔行取數據,一定的間隔;
(6)a[:,:,::2,::3]:兩個冒號直接寫表示從所有的數據中隔行取數據。
2、對於tensor數據的切片與其中某些維度數據的提取方法:
a.index_select(x,torch.tensor([m,n])):表示提取tensor數據a的第x個維度上的索引為m和n的數據
3、torch.masked_select(x,mask):該函數主要用來選取x數據中的mask性質的數據,比如mask=x.ge(0.5)表示選出大於0.5的所有數據,並且輸出時將其轉換為了dim=1的打平tensor數據。
4、#take函數的應用:先將張量數據打平為一個dim=1的張量數據(依次排序下來成為一個數據列),然后按照索引進行取數據
a=torch.tensor([[1,2,3],[4,5,6]])
torch.take(a,torch.tensor([1,2,5])):表示提取a這個tensor數據打平以后的索引為1/2/5的數據元素
(1-2)tensor數據的維度變換
1、對於tensor數據的維度變換主要有四大API函數:
(1)view/reshape:主要是在保證tensor數據大小不變的情況下對tensor數據進行形狀的重新定義與轉換
(2)Squeeze/unsqueeze:刪減維度或者增加維度操作
(3)transpose/t/permute類似矩陣的轉置操作,對於多維的數據具有多次或者單次的轉換操作
(4)Expand/repeat:維度的擴展,將低維數據轉換為高維的數據
2、view(reshape)維度轉換操作時需要保證數據的大小numl保持不變,即數據變換前后的prod是相同的:
prod(a.size)=prod(b.size)
另外,對於view操作有一個致命的缺陷就是在數據進行維度轉換之后數據之前的存儲與維度順序信息會丟失掉,不能夠復原,而這對於訓練的數據來說非常重要。
3、squeeze/unsqueeze擠壓和增加維度操作的函數
a=torch.rand(4,3,28,28)
a.unsqueeze(1):在a原來維度索引1之間增加一個維度
a.unsqueeze(-1):在a原來維度索引-1之后增加維度
例如:
a=torch.tensor([1.2,1.3]) #[2]
print(a.unsqueeze(0)) #[1,2]
print(a.unsqueeze(-1)) #[2,1]
a=torch.rand(4,32,28,28)
b=torch.rand(32) #如果要實現a和數據b的疊加,則需要對於數據b進行維度擴張
print(b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)
4、維度刪減squeeze()
對於維度的擠壓squeeze,主要是擠壓掉tensor數據中維度特征數為1的維度,如果不是1的話就不可以擠壓
b=torch.rand(32)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.squeeze().shape)
print(b.squeeze(0).shape)
print(b.squeeze(1).shape)
print(b.squeeze(-1).shape)
5、維度的擴展:expand(絕對擴展)/repeat(相對擴展)
#維度的擴張expand(絕對值)/repeat,repeat擴展實質是重復拷貝的次數-相對值,並且由於拷貝操作,原來的數據不能再用,已經改變,而expand是絕對擴展,其實現只能從1擴張到n,不能從M擴張到N,另外-1表示對該維度保持不變的操作。
a=torch.rand(4,32,14,14)
b=torch.rand(32)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(a.shape,b.shape)
print(b.expand(4,32,14,14).shape)
print(b.expand(-1,32,-1,-1).shape) #-1表示對維度保持不變
print(b.repeat(4,32,1,1).shape)
print(b.repeat(4,1,14,14).shape)
6、維度交換操作:
(1).t()操作:只可以對DIM=2的矩陣進行轉置操作
(2)transpose操作:對不同的DIM均可以進行維度交換
a=torch.rand(4,3,32,32)
a1=a.transpose(1,3).contiguous().view(4,32*32*3).view(4,32,32,3).transpose(1,3)
print(a1.shape)
print(torch.all(torch.eq(a,a1)))
整體的變換順序為a[b,c,h,w]->[b,w,h,c]->[b,w*h*c]->[b,w,h,c]->[b,c,h,w]
7、permute操作
相比於transpose只可以進行兩個維度之間的一次交換操作,permute維度交換操作可以一步實現多個維度之間的交換(相當於transpose操作的多步操作)
#.t()和transpose/permute維度交換操作,需要考慮數據的信息保存,不能出現數據的污染和混亂.contiguous()操作保持存儲順序不變
c=torch.rand(3,4)
print(c)
print(c.t())
a=torch.rand(4,3,32,32)
a1=a.transpose(1,3).contiguous().view(4,32*32*3).view(4,32,32,3).transpose(1,3)
print(a1.shape)
print(torch.all(torch.eq(a,a1)))
a=torch.rand(4,3,28,32)
a1=a.permute(0,2,3,1)
print(a1.shape)
a2=a.contiguous().permute(0,2,3,1)
print(torch.all(torch.eq(a1,a2)))

對於以上的數據維度變換和索引切片訓練代碼如下所示:
#tensor數據的索引與切片操作
import torch
a=torch.rand(4,3,28,28)
print(a)
print(a.shape)
print(a.dim())
#索引與切片操作
print(a[0].shape)
print(a[0,0,1,2])
print(a[:2].shape)
print(a[:2,:1,:,:].shape)
print(a[:2,1:,:,:].shape)
print(a[:2,-3:].shape)
print(a[:,:,0:28:2,0:28:2].shape)
print(a[:,:,::2,::3].shape)
#選擇其中某維度的某些索引數據
b=torch.rand(5,3,3)
print(b)
print(b.index_select(0,torch.tensor([1,2,4])))
print(b.index_select(2,torch.arange(2)).shape)
#...操作表示自動判斷其中得到維度區間
a=torch.rand(4,3,28,28)
print(a[...,2].shape)
print(a[0,...,::2].shape)
print(a[...].shape)
#msaked_select
x=torch.randn(3,4)
print(x)
mask=x.ge(0.5) #選出所有元素中大於0.5的數據
print(mask)
print(torch.masked_select(x,mask)) #選出所有元素中大於0.5的數據,並且輸出時將其轉換為了dim=1的打平tensor數據
#take函數的應用:先將張量數據打平為一個dim=1的張量數據(依次排序下來成為一個數據列),然后按照索引進行取數據
a=torch.tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
print(torch.take(a,torch.tensor([1,2,5])))

#tensor數據的維度變換
#view/reshape操作:不進行額外的記住和存貯就會丟失掉原來的數據的數據和維度順序信息,而這是非常重要的
a=torch.rand(4,1,28,28)
print(a.view(4,28*28))
b=a.view(4,28*28)
print(b.shape)
#squeeze/unsqueeze擠壓和增加維度的操作
a=torch.rand(4,3,28,28)
print(a)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(-4).shape)
a=torch.tensor([1.2,1.3]) #[2]
print(a.unsqueeze(0)) #[1,2]
print(a.unsqueeze(-1)) #[2,1]
a=torch.rand(4,32,28,28)
b=torch.rand(32) #如果要實現a和數據b的疊加,則需要對於數據b進行維度擴張
print(b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)
print(b.squeeze().shape)
print(b.squeeze(0).shape)
print(b.squeeze(1).shape)
print(b.squeeze(-1).shape)
#維度的擴張expand(絕對值)/repeat(重復拷貝的次數-相對值,並且由於拷貝操作,原來的數據不能再用,已經改變),只能從1擴張到n,不能從M擴張到N,另外-1表示對該維度保持不變的操作
a=torch.rand(4,32,14,14)
b=torch.rand(32)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(a.shape,b.shape)
print(b.expand(4,32,14,14).shape)
print(b.expand(-1,32,-1,-1).shape) #-1表示對維度保持不變
print(b.repeat(4,32,1,1).shape)
print(b.repeat(4,1,14,14).shape)
#.t()和transpose/permute維度交換操作,需要考慮數據的信息保存,不能出現數據的污染和混亂.contiguous()操作保持存儲順序不變
c=torch.rand(3,4)
print(c)
print(c.t())
a=torch.rand(4,3,32,32)
a1=a.transpose(1,3).contiguous().view(4,32*32*3).view(4,32,32,3).transpose(1,3)
print(a1.shape)
print(torch.all(torch.eq(a,a1)))
a=torch.rand(4,3,28,32)
a1=a.permute(0,2,3,1)
print(a1.shape)
a2=a.contiguous().permute(0,2,3,1)
print(torch.all(torch.eq(a1,a2)))

最終的實現結果如下所示:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM