pytorch索引與切片


@

index索引

torch會自動從左向右索引

例子:

a = torch.randn(4,3,28,28)

表示類似一個CNN 的圖片的輸入數據,4表示這個batch一共有4張照片,而3表示圖片的通道數為3(RGB),(28,28)表示圖片的大小

基本索引

索引1:表示第零張圖片的shape

print(a[0].shape)
#torch.Size([3,28,28])

索引2:第零張圖片的第零個通道的size

print(a[0,0].shape)
#torch.Size([28,28])

索引3:表示第零張圖片的第零個通道的第二行第四列的像素點的值

print(a[0,0,2,4])
#tensor(0.8082)

連續選取

⭐索引4:連續取兩張圖片(取第0張以及第一張圖片,不包括第二張)

print(a[:2].shape
#torch.Size([2,3,28,28])
#由於是兩張圖片,所以第一維變為2

⭐索引5:前兩張圖片上的第一個通道上的數據(所以通道數變為了1)

print(a[:2,:1,:,:].shape)
print(a[:2,:1].shape)
#torch.Size(2,1,28,28)

⭐索引6:從后面取(-1表示最后一個,從最后一個取到最后,也就是一個通道)

print(a[:2,-1:,:,:].shape)

#torch.Size(2,1,28,28)

規則間隔索引

⭐索引7:在圖片的矩陣進行隔行與隔列索引 0:28:2表示從0到28(不包括28),間隔數為2

print(a[:,:,0:28:2,0:28:2].shape)
print(a[:,:,::2,::2].shape)
#torch.Size([4,3,14,14])

索引總結

start : end : step

都取

x:從x取到最后 :x 從開始取到x x:y從x取到y

x:y:z從x到y每隔z個點采樣一次

不規則間隔索引

使用index_select()函數

第一個參數表示你對哪個維度進行操作;第二個參數是index(必須是tensor類型):對第0張與第2張圖片進行操作

a.index_select(0,torch.tensor([0,2])).shape
#【2,3,28,28】

同理:選擇了兩個通道

a.index_select(1,torch.tensor([1,2])).shape
#【4,2,28,28】

同理:只取8行

a.index_select(2,torch.arange(8)).shape
#【4,2,8,28】

任意多的維度索引

使用符號:...

例子:

a[...].shape
#[4,3,28,28]

a[0,...].shape
#[3,28,28]

a[0,1,...].shape
#[4,28,28]

a[...,2].shape
#[4,3,28,2]

使用掩碼來索引

函數:.masked_select()會將篩選出來的元素打平(因為無法維護原來的shape)

x = torch.randn(2,3)
print(x)

tensor([[-1.3081, -0.5651, -0.9843],
        [ 1.0051, -0.3829,  0.6300]])

mask = x.ge(0.5)#大於等於0.5的元素
print(mask)

tensor([[False, False, False],
        [ True, False,  True]])

z = torch.masked_select(x,mask)
print(z)

tensor([1.0051, 0.6300])

打平后的索引

例子:使用take函數:是將輸入的tensor打平之后進行index的選擇

src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,8]))
#tensor([4,5,8])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM