Pytorch中的掩碼:dtype=torch.uint8


在pytorch中,dtype=uint8的數據類型往往可以用作掩碼0表示舍棄對應項1表示選取對應項。通過設置不同的0或1的值,對另外的tensor進行選擇性選取:

例如:

t = torch.rand(42) """ tensor([[0.5492, 0.2083],  [0.3635, 0.5198],  [0.8294, 0.9869],  [0.2987, 0.0279]]) """

# 注意以下mask數據類型是 uint8 mask = torch.ones(4,dtype=torch.uint8) mask[2] = 0 print(mask) print(t[mask, :]) # 選取tensor t的第一個維度(由mask所在的位置決定的,這是numpy的花式索引的知識)的第0,1,3個行,以及這三個行對應的所有列;舍棄t的第2行。 """ tensor([1, 1, 0, 1], dtype=torch.uint8) # 因為是uint8類型的,所以當它被另外一個tensor當作索引時,1代表選取,0代表舍棄。uint8只能是0或者1 tensor([[0.5492, 0.2083], [0.3635, 0.5198], [0.2987, 0.0279]]) """ # 注意, 以下數據類型是long,可以和上面的uint8做一下對比 idx = torch.ones(3,dtype=torch.long) idx[1] = 0 print(idx) print(t[idx, :]) """ tensor([1, 0, 1]) # 因為是long類型的,所以當它被另外一個tensor當作索引時,1代表選取對應維度的第一個,0代表選取維度的第0個。當然也可以是其他的整數。 tensor([[0.3635, 0.5198], [0.5492, 0.2083], [0.3635, 0.5198]]) """
再舉個例子:當mask掩碼的形狀和另一個tensor idx的形狀相同,mask作為索引的時候,0或1直接相當於舍去或選取idx相應位置的值

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM