語(yǔ)義分割之MIoU原理與實(shí)現(xiàn)

一、MIoU簡(jiǎn)介

MIoU全稱為Mean Intersection over Union,平均交并比??勺鳛檎Z(yǔ)義分割系統(tǒng)性能的評(píng)價(jià)指標(biāo)。

P:Prediction預(yù)測(cè)值
G:Ground Truth真實(shí)值

MIoU

假設(shè)有3個(gè)類別,各類別iou值如下:

iou_1, iou_2, iou_3 = 0.6, 0.7, 0.75
num_class = 3
miou = sum([iou_1, iou_2, iou_3]) / num_class

二、IoU簡(jiǎn)介

交并比是預(yù)測(cè)值和真實(shí)值的交集和并集之比。


圖取自網(wǎng)絡(luò)

要預(yù)測(cè)的圖像中有人、樹、草等...
我們針對(duì)“人”這個(gè)類別來直觀看下,什么是交并比
左上圖和右上圖,分別是標(biāo)注好的真實(shí)值,和模型輸出的預(yù)測(cè)值。

交集:真實(shí)值和預(yù)測(cè)值的交集就是下面這張圖:
藍(lán)線圈的是真實(shí)值,紅線圈的是預(yù)測(cè)值,黃色區(qū)域就是交集。

并集:真實(shí)值和預(yù)測(cè)值的并集就是下面這張圖:
藍(lán)線圈的是真實(shí)值,紅線圈的是預(yù)測(cè)值,黃色區(qū)域就是并集

混淆矩陣(Confusion Matrix)

混淆矩陣可以用于直觀展示每個(gè)類別的預(yù)測(cè)情況。并能從中計(jì)算精確值(Accuracy)、精確率(Precision)、召回率(Recall)、交并比(IoU)。


混淆矩陣(圖取自網(wǎng)絡(luò))

混淆矩陣是n*n的矩陣(n是類別),對(duì)角線上的是正確預(yù)測(cè)的數(shù)量。

  • 假設(shè)求dog的IoU
    請(qǐng)對(duì)照著附圖和代碼看


    dog_iou
true_dog = (7+2+28+111+18+801+13+17+0+3) # 上圖綠框
predict_dog = (1+0+8+48+13+801+4+17+1+0) # 上圖黃框
# 因?yàn)榉帜傅?01加了兩次,因此要減一次
iou_dog = 801 / true_dog + predict_dog - 801

按照dog求IoU的方法,對(duì)每個(gè)類別進(jìn)行求值,再求平均,就是語(yǔ)義分割模型的MIoU值。
理論上說,MIoU值越大(越接近1),模型效果越好。

P:Prediction預(yù)測(cè)值
G:Ground Truth真實(shí)值

MIoU

代碼實(shí)現(xiàn)

因?yàn)閚umpy能基于數(shù)組計(jì)算,因此MIoU的求解非常簡(jiǎn)潔。

  • 生成混淆矩陣
import numpy as np
def fast_hist(a, b, n):
    """
    生成混淆矩陣
    a 是形狀為(HxW,)的預(yù)測(cè)值
    b 是形狀為(HxW,)的真實(shí)值
    n 是類別數(shù)    
    """
    # 確保a和b在0~n-1的范圍內(nèi),k是(HxW,)的True和False數(shù)列
    k = (a >= 0) & (a < n)
    # 這句會(huì)返回混淆矩陣,具體在下面解釋
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

np.bincount(array)會(huì)返回array當(dāng)中每個(gè)元素的個(gè)數(shù)

例如下面例子的[1, 3, 5]當(dāng)中,0個(gè)0,2個(gè)1,0個(gè)2,1個(gè)3...以此類推

np.bincount([1, 3, 5, 1])
# 返回array([0, 2, 0, 1, 0, 1], dtype=int64)

從上例可以看出,返回值的長(zhǎng)度是輸入array中的最大值加1
最大值是5,返回的列表長(zhǎng)度為6(5+1)

np.bincount(array, minlength)中,minlength就是限制返回列表的最小長(zhǎng)度

np.bincount([1, 3, 5, 1], minlength=10)
# 返回array([0, 2, 0, 1, 0, 1, 0, 0, 0, 0], dtype=int64)

返回長(zhǎng)度為10,末尾填充了0,0,0...

看回源代碼,如果沒有越界情況,a[k]可以看成a,b[k]可以看成b

    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
  • 舉個(gè)例子:

前面8, 9, 4, 7, 6都預(yù)測(cè)正確
對(duì)于預(yù)測(cè)正確的像素來說,n * a + b就是對(duì)角線的值
假設(shè)n=10,有10類
n * a + b就是88, 99, 44, 77, 66

緊接著6預(yù)測(cè)成了5
因此n * a + b就是65

minlength=n**2是為了確?;煜仃囈欢ㄊ莕*n的大小
因?yàn)樽詈髍eshape(n, n)
88, 99, 44, 77, 66就是對(duì)角線上的值(如下圖紅框),65就是預(yù)測(cè)錯(cuò)誤,并且能真實(shí)反映把6預(yù)測(cè)成了5(如下圖藍(lán)框)


  • 計(jì)算IoU
def per_class_iou(hist):
    """
    hist傳入混淆矩陣(n, n)
    """
    # 因?yàn)橄旅嬗谐?,防止分母?的情況報(bào)錯(cuò)
    np.seterr(divide="ignore", invalid="ignore")
    # 交集:np.diag取hist的對(duì)角線元素
    # 并集:hist.sum(1)和hist.sum(0)分別按兩個(gè)維度相加,而對(duì)角線元素加了兩次,因此減一次
    iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    # 把報(bào)錯(cuò)設(shè)回來
    np.seterr(divide="warn", invalid="warn")
    # 如果分母為0,結(jié)果是nan,會(huì)影響后續(xù)處理,因此把nan都置為0
    iou[np.isnan(res)] = 0.
    return iou

per_class_iou返回每個(gè)類別的iou,再求平均就能得到MIoU了。
下面是我進(jìn)行封裝后的代碼塊,歡迎討論優(yōu)化

def fast_hist(a, b, n):
    """
    Return a histogram that's the confusion matrix of a and b
    :param a: np.ndarray with shape (HxW,)
    :param b: np.ndarray with shape (HxW,)
    :param n: num of classes
    :return: np.ndarray with shape (n, n)
    """
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    """
    Calculate the IoU(Intersection over Union) for each class
    :param hist: np.ndarray with shape (n, n)
    :return: np.ndarray with shape (n,)
    """
    np.seterr(divide="ignore", invalid="ignore")
    res = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    np.seterr(divide="warn", invalid="warn")
    res[np.isnan(res)] = 0.
    return res


class ComputeIoU(object):
    """
    IoU: Intersection over Union
    """

    def __init__(self):
        self.cfsmatrix = np.zeros((CONFIG.NUM_CLASSES, CONFIG.NUM_CLASSES), dtype="uint64")  # confusion matrix
        self.ious = dict()

    def get_cfsmatrix(self):
        return self.cfsmatrix

    def get_ious(self):
        self.ious = dict(zip(range(CONFIG.NUM_CLASSES), per_class_iu(self.cfsmatrix)))  # {0: iou, 1: iou, ...}
        return self.ious

    def get_miou(self, ignore=None):
        self.get_ious()
        total_iou = 0
        count = 0
        for key, value in self.ious.items():
            if isinstance(ignore, list) and key in ignore or \
                    isinstance(ignore, int) and key == ignore:
                continue
            total_iou += value
            count += 1
        return total_iou / count

    def __call__(self, pred, label):
        """
        :param pred: [N, H, W]
        :param label:  [N, H, W}
        Channel == 1
        """

        pred = pred.cpu().numpy()
        label = label.cpu().numpy()

        assert pred.shape == label.shape

        self.cfsmatrix += fast_hist(pred.reshape(-1), label.reshape(-1), CONFIG.NUM_CLASSES).astype("uint64")

調(diào)用方式:

from utils.metrics import ComputeIoU
for epoch in epoches:
  training code:
    pass
  valid code:
    for batch in val_loader:
      compute_iou = ComputeIoU()
      pred = model(image)
      pred = torch.argmax(F.softmax(pred, dim=1), dim=1)
      compute_iou(pred, label) # 這時(shí)會(huì)不斷累加混淆矩陣的值
    ious = compute_iou.get_ious()
    miou = compute_iou.get_miou(ignore=0)
    cfsmatrix = compute_iou.get_cfsmatrix()
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容