對於語義分割網絡,其輸出維度為(b, h, w, classes)。沿類別維度(最後一維)取最大值所在的索引(argmax),得到維度為(b, h, w)的結果(若保留該維度則為(b, h, w, 1)),
這相當於得到一幅灰度圖,其亮度值即為類別index。由於類別值的範圍只有[0, num_classes - 1],遠小於255,如果對輸出
結果直接顯示,會得到一幅幾乎純黑的圖。
所以需要對預測結果進行可視化:
將預測結果轉化為RGB圖像。
首先建立預測類別和相應RGB顏色的映射:
from collections import namedtuple

# One row of the class table below. The field order must match the column
# order of every Label(...) call (see the column header comment in `labels`).
Label = namedtuple('Label', [
    'name',          # human-readable class name
    'id',            # value in the table's "id" column (0..33, -1 for license plate)
    'trainId',       # class index used for the network output; 255 / -1 appear
                     # on rows whose ignoreInEval is True (presumably "not trained")
    'category',      # coarse category name, e.g. 'flat', 'vehicle'
    'categoryId',    # numeric id of the category
    'hasInstances',  # flag from the table (True for human / vehicle classes)
    'ignoreInEval',  # flag from the table (True for rows excluded from evaluation)
    'color',         # (R, G, B) colour used when visualising this class
])

labels = [
    #       name                    id  trainId  category        catId  hasInstances  ignoreInEval  color
    Label('unlabeled',             0,  255,     'void',         0,     False,        True,         (0, 0, 0)),
    Label('ego vehicle',           1,  255,     'void',         0,     False,        True,         (0, 0, 0)),
    Label('rectification border',  2,  255,     'void',         0,     False,        True,         (0, 0, 0)),
    Label('out of roi',            3,  255,     'void',         0,     False,        True,         (0, 0, 0)),
    Label('static',                4,  255,     'void',         0,     False,        True,         (0, 0, 0)),
    Label('dynamic',               5,  255,     'void',         0,     False,        True,         (111, 74, 0)),
    Label('ground',                6,  255,     'void',         0,     False,        True,         (81, 0, 81)),
    Label('road',                  7,  0,       'flat',         1,     False,        False,        (128, 64, 128)),
    Label('sidewalk',              8,  1,       'flat',         1,     False,        False,        (244, 35, 232)),
    Label('parking',               9,  255,     'flat',         1,     False,        True,         (250, 170, 160)),
    Label('rail track',            10, 255,     'flat',         1,     False,        True,         (230, 150, 140)),
    Label('building',              11, 2,       'construction', 2,     False,        False,        (70, 70, 70)),
    Label('wall',                  12, 3,       'construction', 2,     False,        False,        (102, 102, 156)),
    Label('fence',                 13, 4,       'construction', 2,     False,        False,        (190, 153, 153)),
    Label('guard rail',            14, 255,     'construction', 2,     False,        True,         (180, 165, 180)),
    Label('bridge',                15, 255,     'construction', 2,     False,        True,         (150, 100, 100)),
    Label('tunnel',                16, 255,     'construction', 2,     False,        True,         (150, 120, 90)),
    Label('pole',                  17, 5,       'object',       3,     False,        False,        (153, 153, 153)),
    Label('polegroup',             18, 255,     'object',       3,     False,        True,         (153, 153, 153)),
    Label('traffic light',         19, 6,       'object',       3,     False,        False,        (250, 170, 30)),
    Label('traffic sign',          20, 7,       'object',       3,     False,        False,        (220, 220, 0)),
    Label('vegetation',            21, 8,       'nature',       4,     False,        False,        (107, 142, 35)),
    Label('terrain',               22, 9,       'nature',       4,     False,        False,        (152, 251, 152)),
    Label('sky',                   23, 10,      'sky',          5,     False,        False,        (70, 130, 180)),
    Label('person',                24, 11,      'human',        6,     True,         False,        (220, 20, 60)),
    Label('rider',                 25, 12,      'human',        6,     True,         False,        (255, 0, 0)),
    Label('car',                   26, 13,      'vehicle',      7,     True,         False,        (0, 0, 142)),
    Label('truck',                 27, 14,      'vehicle',      7,     True,         False,        (0, 0, 70)),
    Label('bus',                   28, 15,      'vehicle',      7,     True,         False,        (0, 60, 100)),
    Label('caravan',               29, 255,     'vehicle',      7,     True,         True,         (0, 0, 90)),
    Label('trailer',               30, 255,     'vehicle',      7,     True,         True,         (0, 0, 110)),
    Label('train',                 31, 16,      'vehicle',      7,     True,         False,        (0, 80, 100)),
    Label('motorcycle',            32, 17,      'vehicle',      7,     True,         False,        (0, 0, 230)),
    Label('bicycle',               33, 18,      'vehicle',      7,     True,         False,        (119, 11, 32)),
    Label('license plate',         -1, -1,      'vehicle',      7,     False,        True,         (0, 0, 142)),
]

# Map trainId -> Label. Iterating in reverse means that, when several rows
# share a trainId (e.g. all the 255 rows), the row that appears FIRST in the
# table is written last and wins (trainId 255 -> 'unlabeled').
trainId2label = {label.trainId: label for label in reversed(labels)}
# {-1: Label('license plate', ...), 18: Label('bicycle', ...), ...}
生成一個與原圖高、寬相同,通道數為3的三維矩陣(uint8,初始全為0,即黑色):
# All-black RGB canvas: same height/width as the class-id map (first two axes
# of class_id_image), three uint8 colour channels.
colored_image = np.zeros(shape=(class_id_image.shape[0], class_id_image.shape[1], 3), dtype=np.uint8)
將每個像素位置填充為其類別對應的RGB顏色:
# Paint every pixel with the RGB colour of its class.
# `class_id_to_rgb_map` maps a class id to an object with a `.color` (R, G, B)
# tuple -- e.g. the `trainId2label` dict built above for Cityscapes train ids.
# Vectorised: one boolean mask per distinct class id instead of a Python-level
# loop over every pixel (which is very slow on full-resolution images).
# int64 cast truncates like the per-pixel int() did; the reshape also accepts
# an (h, w, 1) class map.
class_ids = np.asarray(class_id_image).astype(np.int64).reshape(
    colored_image.shape[:2])
for class_id in np.unique(class_ids):
    try:
        color = class_id_to_rgb_map[int(class_id)].color
    except KeyError as key_error:
        # Unknown ids keep the canvas's initial black; warn once per id.
        print("Warning: could not resolve classid %s" % key_error)
        continue
    colored_image[class_ids == class_id] = color
所以完整流程為:
probs = pspnet.predict(img)
# Class scores are on the last axis, so argmax over axis=-1 gives each pixel's
# class index (identical to axis=2 for an (h, w, classes) output).
cm = np.argmax(probs, axis=-1)
colored_class_image = color_class_image(cm)

# Blend the colour map with the input image. The arithmetic promotes to
# float64; clip and cast back to uint8 so the result can be shown/saved as an
# image. Assumes `img` is an (h, w, 3) image in the 0-255 range -- confirm.
alpha = 0.5  # weight of the segmentation colours
alpha_blended = np.clip(
    alpha * colored_class_image + (1 - alpha) * img, 0, 255).astype(np.uint8)
上面最後一步是將彩色分割圖與原圖按各0.5的權重混合(alpha blending)。

補充:還可以使用PIL內建的調色板(palette)方法:
from PIL import Image  # binds `Image`; the bare name `PIL` is never bound in this file

# Assumes `mask` is a 2-D array of class indices. A 'P' (palette) image holds
# at most 256 indices, so ids must lie in [0, 255]; astype(np.uint8) would
# silently wrap larger ids.
new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
# `palette` is a flat [r0, g0, b0, r1, g1, b1, ...] list: one RGB triple per index.
new_mask.putpalette(palette)
from PIL import Image
# Bind the photo to a name: a bare `Image.open(...)` expression that is not the
# last line of a notebook cell is simply discarded (and in a script it never
# renders), leaving an unused open file.
image = Image.open('PennFudanPed/PNGImages/FudanPed00001.png')
# The mask stores small integer indices (presumably 0 = background and one index
# per pedestrian), which display as near-black grey levels. Attaching a palette
# turns each index into a visible colour; indices above 3 get no colour here.
mask = Image.open('PennFudanPed/PedMasks/FudanPed00001_mask.png')
mask.putpalette([
    0, 0, 0,  # black background
    255, 0, 0,  # index 1 is red
    255, 255, 0,  # index 2 is yellow
    255, 153, 0,  # index 3 is orange
])
