PIL_im = Image.fromarray(np.uint8(img))
PIL_att = Image.fromarray(np.uint8(one_map)).convert('RGB')
運行一部分之后報錯
TypeError: Cannot handle this data type: (1, 1, 6), |u1
第一個解決方案
PIL需要的格式是(W,H,C),而數據集的格式是(C,W,H)所以要進行轉換,把(C,W,H)變為(W,H,C)
把最上邊兩行代碼改成
PIL_im = Image.fromarray(np.uint8(img.transpose(1,2,0)))
PIL_att = Image.fromarray(np.uint8(one_map.transpose(1,2,0))).convert('RGB')
這個方法我的程序多運行了一步。。希望的花花就這么沒有了
第二個方案
有的人說python2.7版本解決了這個問題,我換成2.7版本運行的確是解決了這個問題,但是會出現滿屏幕警告,影響運行速度
警告:UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
解決方案:(原文:https://github.com/microsoft/IRNet/issues/5)
change the return of
from
length_array_to_mask_tensor()/table_dict_to_mask_tensor()/ pred_col_mask()
from
mask = torch.ByteTensor(mask) to mask = torch.BoolTensor(mask).
