之前一直不清楚Top1和Top5是什么,其實搞清楚了很簡單,就是兩種衡量指標,其中,Top1就是普通的Accuracy,Top5比Top1衡量標准更“嚴格”,
具體來講,比如一共需要分10類,每次分類器的輸出結果都是10個相加為1的概率值,Top1就是這十個值中最大的那個概率值對應的分類恰好正確的頻率,而Top5則是在十個概率值中從大到小排序出前五個,然后看看這前五個分類中是否存在那個正確分類,再計算頻率。Pytorch實現如下:
def evaluteTop1(model, loader): model.eval() correct = 0 total = len(loader.dataset) for x,y in loader: x,y = x.to(device), y.to(device) with torch.no_grad(): logits = model(x) pred = logits.argmax(dim=1) correct += torch.eq(pred, y).sum().float().item() #correct += torch.eq(pred, y).sum().item() return correct / total def evaluteTop5(model, loader): model.eval() correct = 0 total = len(loader.dataset) for x, y in loader: x,y = x.to(device),y.to(device) with torch.no_grad(): logits = model(x) maxk = max((1,5))
y_resize = y.view(-1,1) _, pred = logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total
注意:y_resize = y.view(-1,1)是非常關鍵的一步,在correct的運算中,關鍵就是要pred和y_resize維度匹配,而原來的y是[128],128是batch大小;
pred的維度則是[128,10],假設這里是CIFAR10十分類;因此必須把y轉化成[128,1]這種維度,但是不能直接是y.view(128,1),因為遍歷整個數據集的時候,
最后一個batch大小並不是128,所以view()里面第一個size就設為-1未知,而確保第二個size是1就行
topk函數的具體用法參見https://blog.csdn.net/u014264373/article/details/86525621