pytorch tensor的索引與切片


tensor索引與numpy類似,支持冒號,和數字直接索引

import torch

a = torch.Tensor(2, 3, 4)
a
# 輸出:
      tensor([[[9.2755e-39, 1.0561e-38, 9.7347e-39, 1.1112e-38],
             [1.0194e-38, 8.4490e-39, 1.0102e-38, 9.0919e-39],
             [1.0102e-38, 8.9082e-39, 8.4489e-39, 1.0102e-38]],
    
            [[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
             [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
             [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]]])

# 冒號索引與數字索引
a[:1, :2, 1]
# 輸出:
      tensor([[1.0561e-38, 8.4490e-39]])

# 通過-1索引
a[-1]
# 輸出:
      tensor([[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
            [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
            [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]])

...(三個點)索引

用於維度過多,且取中間多個維度所有數據的情況

# 生成多維數據
a = torch.rand(1,2,3,2,4,5)
a
# 輸出:
     tensor([[[[[[0.1954, 0.1918, 0.3053, 0.3649, 0.3637],
                [0.8467, 0.0205, 0.2187, 0.8438, 0.1754],
                [0.7076, 0.7047, 0.1852, 0.5374, 0.7024],
                [0.5630, 0.4526, 0.0662, 0.9463, 0.9294]],
    
               [[0.6917, 0.5505, 0.5770, 0.3819, 0.9541],
                [0.8957, 0.2530, 0.4858, 0.1866, 0.2542],
                [0.3745, 0.2125, 0.5537, 0.5642, 0.2284],
                [0.2634, 0.1147, 0.1793, 0.0277, 0.9800]]], 

              ...

              [[[0.9949, 0.2210, 0.3365, 0.0852, 0.4387],
                [0.6440, 0.6391, 0.9141, 0.2288, 0.6203],
                [0.0474, 0.7894, 0.4362, 0.9752, 0.7546],
                [0.1234, 0.0246, 0.1436, 0.0053, 0.3405]],
    
               [[0.8174, 0.9021, 0.0420, 0.2045, 0.2140],
                [0.4844, 0.6342, 0.2965, 0.9299, 0.2284],
                [0.1420, 0.1834, 0.0581, 0.8467, 0.8987],
                [0.8012, 0.1526, 0.4293, 0.3928, 0.5437]]]]]]) 

# 取第一維和最后一維的0索引數據,中間所有維度數據全部取出
a[0, ..., 0]
# 輸出:
      tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
              [0.6917, 0.8957, 0.3745, 0.2634]],
    
             [[0.4374, 0.0534, 0.6809, 0.7086],
              [0.2231, 0.6680, 0.8643, 0.9057]],
    
             [[0.8169, 0.0649, 0.5923, 0.3802],
              [0.2562, 0.0095, 0.8557, 0.6828]]],
    
    
            [[[0.1514, 0.3948, 0.6452, 0.6332],
              [0.8872, 0.7304, 0.6853, 0.9814]],
    
             [[0.5736, 0.5195, 0.9711, 0.5575],
              [0.6778, 0.9334, 0.5647, 0.1006]],
    
             [[0.9949, 0.6440, 0.0474, 0.1234],
              [0.8174, 0.4844, 0.1420, 0.8012]]]])

# 上面等價於
a[0,:,:,:,:,0]
# 輸出:
      tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
              [0.6917, 0.8957, 0.3745, 0.2634]],
    
             [[0.4374, 0.0534, 0.6809, 0.7086],
              [0.2231, 0.6680, 0.8643, 0.9057]],
    
             [[0.8169, 0.0649, 0.5923, 0.3802],
              [0.2562, 0.0095, 0.8557, 0.6828]]],
    
    
            [[[0.1514, 0.3948, 0.6452, 0.6332],
              [0.8872, 0.7304, 0.6853, 0.9814]],
    
             [[0.5736, 0.5195, 0.9711, 0.5575],
              [0.6778, 0.9334, 0.5647, 0.1006]],
    
             [[0.9949, 0.6440, 0.0474, 0.1234],
              [0.8174, 0.4844, 0.1420, 0.8012]]]])
可以看出,使用...可以節省操作。

masked_select

# 生成隨機數據
a = torch.randn(3, 4)
a
# 輸出:
    tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
            [ 0.4734, -0.7182, -0.1516,  0.0209],
            [ 0.5089, -0.8130, -0.4519, -0.6190]])

# 大於0.5的數據返回True
mask = a.ge(0.5)
mask
# 輸出:
    tensor([[ True,  True, False, False],
            [False, False, False, False],
            [ True, False, False, False]])

# 通過上面生成的bool數據,利用masked_select來選擇大於0.5的數據
torch.masked_select(a, mask)
# 輸出:
    tensor([0.8710, 0.8862, 0.5089])  

take

a
# 輸出:
      tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
            [ 0.4734, -0.7182, -0.1516,  0.0209],
            [ 0.5089, -0.8130, -0.4519, -0.6190]])

# 先將數據打平展開為一維,再選取展開后對應索引[0, 5, 8, 11]的數據
torch.take(a, torch.tensor([0, 5, 8, 11]))
# 輸出:
      tensor([ 0.8710, -0.7182,  0.5089, -0.6190])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM