mIoU混淆矩陣生成函數代碼詳解


代碼參考博客原文: https://blog.csdn.net/jiongnima/article/details/84750819

在原文和原文的引用里,找到了關於mIoU詳盡的解釋。這里重點解析 fast_hist(a, b, n) 這個函數的代碼。

生成混淆矩陣的代碼: 

1 #設標簽寬W,長H
2 def fast_hist(a, b, n):#a是轉化成一維數組的標簽,形狀(H×W,);b是轉化成一維數組的標簽,形狀(H×W,);n是類別數目,實數(在這里為19)
3     '''
4     核心代碼
5     '''
6     k = (a >= 0) & (a < n)#k是一個一維bool數組,形狀(H×W,);目的是找出標簽中需要計算的類別(去掉了背景)
7     return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)#np.bincount計算了從0到n**2-1這n**2個數中每個數出現的次數,返回值形狀(n, n)

在調用了 k = (a >= 0) & (a < n) 以后,得到了bool數組,那它長什么樣子呢?舉個栗子說明:

構造一個4×4的數組a,把背景值設置為255,除背景外類別共3個,分別為1, 2, 3

mushroomer@mushroomerMate:~$ python3
Python 3.7.1rc2 (default, Jun 14 2019, 23:23:01) 
[GCC 5.4.0 20160609] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> a = np.array([[255, 0, 0, 255], [255, 255, 2, 2], [1, 1, 1, 255], [255, 255, 255, 255]])>>> a
array([[255,   0,   0, 255],
       [255, 255,   2,   2],
       [  1,   1,   1, 255],
       [255, 255, 255, 255]])
>>> n = 3
>>> k = (a >= 0) & (a < n)
>>> k
array([[False,  True,  True, False],
       [False, False,  True,  True],
       [ True,  True,  True, False],
       [False, False, False, False]])
>>> a[k]
array([0, 0, 2, 2, 1, 1, 1])

可以看出,k是個和a尺寸相同的bool數組,有效類別都標記為True,背景全部標記為False

a[k] 會把 k 標記的 True 對應在 a 中的值都提取出來。

 

再以 n = 3 為例,混淆矩陣如下:

混淆矩陣映射關系:

$index=n*class(a)+class(b)$

之后是np.bincount, 這個函數統計下標在目標列表中出現的次數。例如:

Python 3.7.1 (default, Dec 10 2018, 22:54:23) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> np.bincount([0, 0, 0, 2, 1, 1, 3])
array([3, 2, 1, 1], dtype=int64)
>>> np.bincount([0, 0, 0, 2, 1, 1, 3], minlength=7)
array([3, 2, 1, 1, 0, 0, 0], dtype=int64)
>>> np.bincount([0, 0, 0, 2, 1, 1, 9], minlength=7)
array([3, 2, 1, 0, 0, 0, 0, 0, 0, 1], dtype=int64)

列表中最大值為3,統計 [0, 1, 2, 3] 對應每個元素在輸入列表中出現的次數,得到 [3, 2, 1, 1], 含義是:0出現3次;1出現2次;2出現1次;3出現1次。

如果指定 minlength, 則認為列表中最大值為 max_value = max(max([0, 0, 0, 2, 1, 1, 3]), minlength),然后去統計 list(range(max_value)) 對應每個元素在輸入列表中出現的次數。

 

在 fast_hist 函數中指定 minlength = n ** 2, 目的是使輸出長度為 n ** 2, 輸出形狀就正好可以轉換為 n * n 矩陣。當然根據 np.bincount 函數的特性,類別值如果超過 minlength,輸出長度就不是 n ** 2 了,因此我舉的栗子里背景值為 255 顯然是不合適的,^_^,意識到了嗎?

然后統計出來混淆矩陣每個 index 對應的 (class a 重疊 class b) 出現的次數,就得到了結果。這里的映射關系重點是要理解每個 index 都對應唯一一個 class a 重疊 class b,例如 n = 3, class a = 1, class b = 2,那么對應的 index = 3*1 + 2 = 5,對應填到混淆矩陣里。假如 class a = 2, class b = 1, 那 index = 3*2 + 1 = 7,index 就變成了7,這個 index 是一一對應的。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM