1. Basic indexing
Example
import torch as t

a = t.Tensor(4, 5)   # uninitialized 4x5 tensor, so the values are garbage
print(a)
print(a[0:1, :2])
print(a[0, :2])      # note: same values as the previous index, but a different shape
print(a[[1, 2]])     # indexing with a list (container)
 3.3845e+15  0.0000e+00  3.3846e+15  0.0000e+00  3.3845e+15
 0.0000e+00  3.3845e+15  0.0000e+00  3.3418e+15  0.0000e+00
 3.3845e+15  0.0000e+00  3.3846e+15  0.0000e+00  0.0000e+00
 0.0000e+00  1.5035e+38  8.5479e-43  1.5134e-43  1.2612e-41
[torch.FloatTensor of size 4x5]

 3.3845e+15  0.0000e+00
[torch.FloatTensor of size 1x2]

 3.3845e+15
 0.0000e+00
[torch.FloatTensor of size 2]

 0.0000e+00  3.3845e+15  0.0000e+00  3.3418e+15  0.0000e+00
 3.3845e+15  0.0000e+00  3.3846e+15  0.0000e+00  0.0000e+00
[torch.FloatTensor of size 2x5]
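The shape differences above can be checked directly. A minimal sketch, assuming a modern PyTorch build and using `zeros` instead of an uninitialized tensor so the values are reproducible:

```python
import torch as t

a = t.zeros(4, 5)
print(a[0:1, :2].shape)  # torch.Size([1, 2]) -- slicing keeps the dimension
print(a[0, :2].shape)    # torch.Size([2])    -- an integer index drops it
print(a[[1, 2]].shape)   # torch.Size([2, 5]) -- a list index selects whole rows
```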
Memory analysis of basic indexing
The result of basic indexing shares memory with the original Tensor.
print(a[a > 1])

import copy
b = copy.deepcopy(a)  # deep copy, so b gets its own storage
a[a > 1] = 10         # in-place masked assignment modifies a but not b
print(a, b)
 3.3845e+15
 3.3846e+15
 3.3845e+15
 3.3845e+15
 3.3418e+15
 3.3845e+15
 3.3846e+15
 1.5035e+38
[torch.FloatTensor of size 8]

 10.0000   0.0000  10.0000   0.0000  10.0000
  0.0000  10.0000   0.0000  10.0000   0.0000
 10.0000   0.0000  10.0000   0.0000   0.0000
  0.0000  10.0000   0.0000   0.0000   0.0000
[torch.FloatTensor of size 4x5]

 3.3845e+15  0.0000e+00  3.3846e+15  0.0000e+00  3.3845e+15
 0.0000e+00  3.3845e+15  0.0000e+00  3.3418e+15  0.0000e+00
 3.3845e+15  0.0000e+00  3.3846e+15  0.0000e+00  0.0000e+00
 0.0000e+00  1.5035e+38  8.5479e-43  1.5134e-43  1.2612e-41
[torch.FloatTensor of size 4x5]
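The sharing behaviour can also be demonstrated more directly: a slice is a view into the same storage, while reading with a boolean mask returns a copy (in-place masked assignment, as above, still writes through to the original). A sketch, assuming a modern PyTorch build:

```python
import torch as t

a = t.arange(16.).view(4, 4)
s = a[0, :2]        # basic indexing: s is a view of a's storage
s[0] = 100.0        # writing through the view changes a as well
print(a[0, 0])      # now 100

m = a[a > 50]       # mask *read*: m is a new tensor (a copy)
m[0] = -1.0         # does not touch a
print(a[0, 0])      # still 100
```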
The gather indexing function
Description from the official documentation:
If input is an n-dimensional tensor of size (x0, x1, …, x_{i-1}, x_i, x_{i+1}, …, x_{n-1}) and dim is i, then index must also be an n-dimensional tensor, of size (x0, x1, …, x_{i-1}, y, x_{i+1}, …, x_{n-1}) with y >= 1, and the output out has the same size as index.
In other words, gather collects values along a specified axis (dimension).

For a three-dimensional tensor:
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
Parameters:

input (Tensor) – the source tensor
dim (int) – the axis (dimension) along which to index
index (LongTensor) – the indices of the elements to gather
out (Tensor, optional) – the destination tensor
In short, gather picks values from the input Tensor along one chosen dimension (the dim parameter), using the index Tensor to select which element along that dimension, while iterating over all the other dimensions.
a = t.arange(16).view(4, 4)
index = t.LongTensor([[0, 1, 2, 3]])
print(a)
print(index)
print(a.gather(0, index))   # picks the diagonal: a[0,0], a[1,1], a[2,2], a[3,3]

# scatter_ is the inverse operation; note that it is in-place
b = t.zeros(4, 4)
b.scatter_(0, index, a.gather(0, index))
print(b)
  0   1   2   3
  4   5   6   7
  8   9  10  11
 12  13  14  15
[torch.FloatTensor of size 4x4]

 0  1  2  3
[torch.LongTensor of size 1x4]

  0   5  10  15
[torch.FloatTensor of size 1x4]

  0   0   0   0
  0   5   0   0
  0   0  10   0
  0   0   0  15
[torch.FloatTensor of size 4x4]
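For comparison, gather with dim=1 selects one element per row, following the rule out[i][j] = input[i][index[i][j]] from the formulas above. A small sketch:

```python
import torch as t

a = t.arange(16.).view(4, 4)
# one column index per row: column 3 for row 0, column 2 for row 1, ...
index = t.LongTensor([[3], [2], [1], [0]])
print(a.gather(1, index))  # the anti-diagonal: a[0,3], a[1,2], a[2,1], a[3,0]
```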
2. Advanced indexing
Unlike basic indexing, advanced indexing generally does not share memory with the original Tensor; this will come up again later when we look at Tensor memory layout.
x = t.arange(0, 27).view(3, 3, 3)
print(x)
print(x[[1, 2], [1, 2], [2, 0]])   # x[1,1,2] and x[2,2,0]
print(x[[2, 1, 0], [0], [0]])      # x[2,0,0], x[1,0,0] and x[0,0,0]
(0 ,.,.) =
   0   1   2
   3   4   5
   6   7   8

(1 ,.,.) =
   9  10  11
  12  13  14
  15  16  17

(2 ,.,.) =
  18  19  20
  21  22  23
  24  25  26
[torch.FloatTensor of size 3x3x3]

 14
 24
[torch.FloatTensor of size 2]

 18
  9
  0
[torch.FloatTensor of size 3]
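The no-sharing claim above can be verified by writing into the result of an advanced index and checking that the original is untouched. A sketch, assuming a modern PyTorch build:

```python
import torch as t

x = t.arange(27.).view(3, 3, 3)
y = x[[1, 2], [1, 2], [2, 0]]  # advanced indexing returns a new tensor
y[0] = -1.0                    # modifying y...
print(x[1, 1, 2])              # ...leaves x unchanged: still 14
```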