在pytorch中,dtype=uint8的數據類型往往可以用作掩碼,0表示舍棄對應項,1表示選取對應項。通過設置不同的0或1的值,對另外的tensor進行選擇性選取:
例如:
t = torch.rand(4,2) """ tensor([[0.5492, 0.2083], [0.3635, 0.5198], [0.8294, 0.9869], [0.2987, 0.0279]]) """
# 注意以下mask數據類型是 uint8 mask = torch.ones(4,dtype=torch.uint8) mask[2] = 0 print(mask) print(t[mask, :]) # 選取tensor t的第一個維度(由mask所在的位置決定的,這是numpy的花式索引的知識)的第0,1,3個行,以及這三個行對應的所有列;舍棄t的第2行。 """ tensor([1, 1, 0, 1], dtype=torch.uint8) # 因為是uint8類型的,所以當它被另外一個tensor當作索引時,1代表選取,0代表舍棄。uint8只能是0或者1 tensor([[0.5492, 0.2083], [0.3635, 0.5198], [0.2987, 0.0279]]) """ # 注意, 以下數據類型是long,可以和上面的uint8做一下對比 idx = torch.ones(3,dtype=torch.long) idx[1] = 0 print(idx) print(t[idx, :]) """ tensor([1, 0, 1]) # 因為是long類型的,所以當它被另外一個tensor當作索引時,1代表選取對應維度的第一個,0代表選取維度的第0個。當然也可以是其他的整數。 tensor([[0.3635, 0.5198], [0.5492, 0.2083], [0.3635, 0.5198]]) """
再舉個例子:當mask掩碼的形狀和另一個tensor idx的形狀相同,mask作為索引的時候,0或1直接相當於舍去或選取idx相應位置的值
