@
index索引
torch會自動從左向右索引
例子:
a = torch.randn(4,3,28,28)
表示類似一個CNN 的圖片的輸入數據,4表示這個batch一共有4張照片,而3表示圖片的通道數為3(RGB),(28,28)表示圖片的大小
基本索引
索引1:表示第零張圖片的shape
print(a[0].shape)
#torch.Size([3,28,28])
索引2:第零張圖片的第零個通道的size
print(a[0,0].shape)
#torch.Size([28,28])
索引3:表示第零張圖片的第零個通道的第二行第四列的像素點的值
print(a[0,0,2,4])
#tensor(0.8082)
連續選取
⭐索引4:連續取兩張圖片(取第0張以及第一張圖片,不包括第二張
)
print(a[:2].shape
#torch.Size([2,3,28,28])
#由於是兩張圖片,所以第一維變為2
⭐索引5:前兩張圖片上的第一個通道上的數據(所以通道數變為了1)
print(a[:2,:1,:,:].shape)
print(a[:2,:1].shape)
#torch.Size(2,1,28,28)
⭐索引6:從后面取(-1表示最后一個,從最后一個取到最后,也就是一個通道)
print(a[:2,-1:,:,:].shape)
#torch.Size(2,1,28,28)
規則間隔索引
⭐索引7:在圖片的矩陣進行隔行與隔列索引 0:28:2表示從0到28(不包括28),間隔數為2
print(a[:,:,0:28:2,0:28:2].shape)
print(a[:,:,::2,::2].shape)
#torch.Size([4,3,14,14])
索引總結
start : end : step
:
都取
x:
從x取到最后 :x
從開始取到x x:y
從x取到y
x:y:z
從x到y每隔z個點采樣一次
不規則間隔索引
使用index_select()函數
第一個參數表示你對哪個維度進行操作;第二個參數是index(必須是tensor類型
):對第0張與第2張圖片進行操作
a.index_select(0,torch.tensor([0,2])).shape
#【2,3,28,28】
同理:選擇了兩個通道
a.index_select(1,torch.tensor([1,2])).shape
#【4,2,28,28】
同理:只取8行
a.index_select(2,torch.arange(8)).shape
#【4,2,8,28】
任意多的維度索引
使用符號:...
例子:
a[...].shape
#[4,3,28,28]
a[0,...].shape
#[3,28,28]
a[0,1,...].shape
#[4,28,28]
a[...,2].shape
#[4,3,28,2]
使用掩碼來索引
函數:.masked_select()
會將篩選出來的元素打平(因為無法維護原來的shape)
x = torch.randn(2,3)
print(x)
tensor([[-1.3081, -0.5651, -0.9843],
[ 1.0051, -0.3829, 0.6300]])
mask = x.ge(0.5)#大於等於0.5的元素
print(mask)
tensor([[False, False, False],
[ True, False, True]])
z = torch.masked_select(x,mask)
print(z)
tensor([1.0051, 0.6300])
打平后的索引
例子:使用take函數:是將輸入的tensor打平之后進行index的選擇
src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,8]))
#tensor([4,5,8])