Pytorch-索引與切片


引言

本篇介紹Pytorch 的索引與切片

索引

1
2
3
4
5
6
7
In[3]: a = torch.rand(4,3,28,28)
In[4]: a[0].shape # 理解上相當於取第一張圖片
Out[4]: torch.Size([3, 28, 28])
In[5]: a[0,0].shape # 第0張圖片的第0個通道
Out[5]: torch.Size([28, 28])
In[6]: a[0,0,2,4] # 第0張圖片第0個通道的第2行第4列的像素點 標量
Out[6]: tensor(0.4133) # 沒有用 [] 包起來就是一個標量 dim為0

切片

  • 顧頭不顧尾
1
2
3
4
5
6
7
8
9
10
In[7]: a.shape
Out[7]: torch.Size([4, 3, 28, 28])
In[8]: a[:2].shape # 前面兩張圖片的所有數據
Out[8]: torch.Size([2, 3, 28, 28])
In[9]: a[:2,:1,:,:].shape # 前面兩張圖片的第0通道的數據
Out[9]: torch.Size([2, 1, 28, 28])
In[11]: a[:2,1:,:,:].shape # 前面兩張圖片,第1,2通道的數據
Out[11]: torch.Size([2, 2, 28, 28])
In[10]: a[:2,-1:,:,:].shape # 前面兩張圖片,最后一個通道的數據 從-1到最末尾,就是它本身。
Out[10]: torch.Size([2, 1, 28, 28])

步長

  • 顧頭不顧尾 + 步長
  • start : end : step
  • 對於步長為1的,通常就省略了。
1
2
3
4
a[:,:,0:28,0:28:2].shape    # 隔點采樣
Out[12]: torch.Size([4, 3, 28, 14])
a[:,:,::2,::2].shape
Out[14]: torch.Size([4, 3, 14, 14])

具體的索引

  • .index_select(dim, indices)
    • dim為維度,indices是索引序號
    • 這里的indeces必須是tensor ,不能直接是一個list
1
2
3
4
5
6
7
8
9
10
In[17]: a.shape
Out[17]: torch.Size([4, 3, 28, 28])
In[19]: a.index_select(0, torch.tensor([0,2])).shape # 當前維度為0,取第0,2張圖片
Out[19]: torch.Size([2, 3, 28, 28])
In[20]: a.index_select(1, torch.tensor([1,2])).shape # 當前維度為1,取第1,2個通道
Out[20]: torch.Size([4, 2, 28, 28])
In[21]: a.index_select(2,torch.arange(28)).shape # 第二個參數,只是告訴你取28行
Out[21]: torch.Size([4, 3, 28, 28])
In[22]: a.index_select(2, torch.arange(8)).shape # 取8行 [0,8)
Out[22]: torch.Size([4, 3, 8, 28])

...

  • ... 表示任意多維度,根據實際的shape來推斷。
  • 當有 ... 出現時,右邊的索引理解為最右邊
  • 為什么會有它,沒有它的話,存在這樣一種情況 a[0,: ,: ,: ,: ,: ,: ,: ,: ,: ,2] 只對最后一個維度做了限度,這個向量的維度又很高,以前的方式就不太方便了。
1
2
3
4
5
6
7
8
9
10
In[23]: a.shape
Out[23]: torch.Size([4, 3, 28, 28])
In[24]: a[...].shape # 所有維度
Out[24]: torch.Size([4, 3, 28, 28])
In[25]: a[0,...].shape # 后面都有,取第0個圖片 = a[0]
Out[25]: torch.Size([3, 28, 28])
In[26]: a[:,1,...].shape
Out[26]: torch.Size([4, 28, 28])
In[27]: a[...,:2].shape # 當有...出現時,右邊的索引理解為最右邊,只取兩列
Out[27]: torch.Size([4, 3, 28, 2])

使用mask來索引

  • .masked_select()
  • 求掩碼位置原來的元素大小
  • 缺點:會把數據,默認打平(flatten),
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
In[31]: x = torch.randn(3,4)
In[32]: x
Out[32]:
tensor([[ 2.0373, 0.1586, 0.1093, -0.6493],
[ 0.0466, 0.0562, -0.7088, -0.9499],
[-1.2606, 0.6300, -1.6374, -1.6495]])
In[33]: mask = x.ge(0.5) # >= 0.5 的元素的位置上為1,其余地方為0
In[34]: mask
Out[34]:
tensor([[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 0, 0]], dtype=torch.uint8)
In[35]: torch.masked_select(x,mask)
Out[35]: tensor([2.0373, 0.6300]) # 之所以打平是因為大於0.5的元素個數是根據內容才能確定的
In[36]: torch.masked_select(x,mask).shape
Out[36]: torch.Size([2])

使用打平(flatten)后的序列

  • torch.take(src, torch.tensor([index]))
  • 打平后,按照index來取對應位置的元素
1
2
3
4
5
6
7
In[39]: src = torch.tensor([[4,3,5],[6,7,8]])		# 先打平成1維的,共6列
In[40]: src
Out[40]:
tensor([[4, 3, 5],
[6, 7, 8]])
In[41]: torch.take(src, torch.tensor([0, 2, 5])) # 取打平后編碼,位置為0 2 5
Out[41]: tensor([4, 5, 8])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM