Pytorch中的掩码:dtype=torch.uint8


在pytorch中,dtype=uint8的数据类型往往可以用作掩码0表示舍弃对应项1表示选取对应项。通过设置不同的0或1的值,对另外的tensor进行选择性选取:

例如:

t = torch.rand(42) """ tensor([[0.5492, 0.2083],  [0.3635, 0.5198],  [0.8294, 0.9869],  [0.2987, 0.0279]]) """

# 注意以下mask数据类型是 uint8 mask = torch.ones(4,dtype=torch.uint8) mask[2] = 0 print(mask) print(t[mask, :]) # 选取tensor t的第一个维度(由mask所在的位置决定的,这是numpy的花式索引的知识)的第0,1,3个行,以及这三个行对应的所有列;舍弃t的第2行。 """ tensor([1, 1, 0, 1], dtype=torch.uint8) # 因为是uint8类型的,所以当它被另外一个tensor当作索引时,1代表选取,0代表舍弃。uint8只能是0或者1 tensor([[0.5492, 0.2083], [0.3635, 0.5198], [0.2987, 0.0279]]) """ # 注意, 以下数据类型是long,可以和上面的uint8做一下对比 idx = torch.ones(3,dtype=torch.long) idx[1] = 0 print(idx) print(t[idx, :]) """ tensor([1, 0, 1]) # 因为是long类型的,所以当它被另外一个tensor当作索引时,1代表选取对应维度的第一个,0代表选取维度的第0个。当然也可以是其他的整数。 tensor([[0.3635, 0.5198], [0.5492, 0.2083], [0.3635, 0.5198]]) """
再举个例子:当mask掩码的形状和另一个tensor idx的形状相同,mask作为索引的时候,0或1直接相当于舍去或选取idx相应位置的值

 

 

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM