Pytorch實現Top1准確率和Top5准確率


之前一直不清楚Top1和Top5是什么,其實搞清楚了很簡單,就是兩種衡量指標,其中,Top1就是普通的Accuracy,Top5比Top1衡量標准更“嚴格”,

具體來講,比如一共需要分10類,每次分類器的輸出結果都是10個相加為1的概率值,Top1就是這十個值中最大的那個概率值對應的分類恰好正確的頻率,而Top5則是在十個概率值中從大到小排序出前五個,然后看看這前五個分類中是否存在那個正確分類,再計算頻率。Pytorch實現如下:

def evaluteTop1(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        #correct += torch.eq(pred, y).sum().item()
    return correct / total

def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1,5))
        y_resize = y.view(-1,1) _, pred
= logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total

注意:y_resize = y.view(-1,1)是非常關鍵的一步,在correct的運算中,關鍵就是要pred和y_resize維度匹配,而原來的y是[128],128是batch大小;

pred的維度則是[128,10],假設這里是CIFAR10十分類;因此必須把y轉化成[128,1]這種維度,但是不能直接是y.view(128,1),因為遍歷整個數據集的時候,

最后一個batch大小並不是128,所以view()里面第一個size就設為-1未知,而確保第二個size是1就行

topk函數的具體用法參見https://blog.csdn.net/u014264373/article/details/86525621


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM