第二次寫關於MSDNet方面的學習筆記,終於是弄懂了大致的全部代碼,也熟悉了其處理過程,先將第二階段的學習成果做一下介紹。
想看上一篇學習筆記(主要是翻譯大致內容)的可以看這個鏈接:https://www.cnblogs.com/liuyangcode/p/13700393.html
關於MSDNet的源代碼鏈接在上一篇學習筆記里已經貼出,本篇筆記主要是關於代碼中dynamic方式如何實現的一個詳細介紹,也就是所謂的budgeted batch classification
具體完整的代碼和注釋可以看這個鏈接:https://github.com/Liuyang829/testMSDNet/blob/main/adaptive_inference.py
文中所描述兩種評估方法:budgeted batch,anytime
文中所描述的budgeted batch是對於一個batch中的眾多樣本來說,共用一個固定的資源限制,可以自適應的在這一個batch中讓簡單樣本使用少一些的資源,讓復雜樣本使用多一些的資源。例如對於大型公司所需要處理的大量數據來說,如果能再簡單樣本上節省一點點的時間 對於之后的總體計算花費來說都是划算的。在這樣的一個batch中,如果計算資源總限制為B,這個batch中包含M個樣本數據,簡單樣本所分的資源應該小於B/M,復雜樣本應該大於B/M。但所描述的相對來說比較抽象,代碼上的實現遠比所描述的復雜,而且從下面的圖像上看,也可以看出這樣一條曲線需要多組測試數據與測試結果。
相比來說 ,文中所說的anytime prediction 指的是可以使網絡在任何給定的時間內給出預測結果。在代碼上的實現相對來說較為簡單,原文中的測試結果圖如下所示,橫坐標表示flops,縱坐標表示acc,至於flops是如何計算而來的,主要是看網絡結構。MSDNet的一個重要特點就是有多個分類器,樣本在測試時如果從淺層分類器輸出自然其flops會比較小,所以在模型結構確定下來后對於模型flops的計算主要取決於有多少個分類器,在逐層計算flops與prams的過程中,如果遇到Liner層,就會輸出一次結果。也就意味着整個網絡模型根據classifer位置的不同來確定flops的大小。對於anytime的測試過程,就是所有的數據在每個分類器都輸出一次結果,所以在下圖中有幾個點就代表有幾個分類器,也就是幾個blocks,代表在該分類器所有樣本全部退出所得到的正確率。
代碼實現 budgeted batch
該模塊代碼上的實現一直困惑着我,終究怎么樣才能達到在不同分類器輸出的動態推理,重點就在adapative_inference.py
總體來說動態退出實現思路在於,對於每一個分類器都找到一個退出的閾值,在執行測試的過程中,如果樣本在某一分類器的置信度超過了該分類器的閾值,則代表該樣本找到了屬於他的出口,所以核心在於每個分類器都要找到一個合適的閾值,這樣就可以對所有輸出的樣本進行判斷
首先在動態處理的主函數中對於所有的驗證集與測試集都放入calc_logit()函數來計算一次所有的輸出置信度,驗證集進行計算的目的用於計算閾值,測試集就是為了根據置信度和閾值判斷最后在哪輸出和最終結果。設立了一個40次的for循環的意義就在於生成了40組分類器的權重數據,代表了每個分類器要輸出百分之多少的數據樣本量,然后再根據這個限制去測試結果。
1 def dynamic_evaluate(model, test_loader, val_loader, args): 2 tester = Tester(model, args) 3 if os.path.exists(os.path.join(args.save, 'logits_single.pth')): 4 val_pred, val_target, test_pred, test_target = \ 5 torch.load(os.path.join(args.save, 'logits_single.pth')) 6 else: 7 # 這里對於驗證集與測試集分別計算每個分類器對於每一個樣本的一個預測結果置信度 8 val_pred, val_target = tester.calc_logit(val_loader) 9 test_pred, test_target = tester.calc_logit(test_loader) 10 torch.save((val_pred, val_target, test_pred, test_target), 11 os.path.join(args.save, 'logits_single.pth')) 12 13 flops = torch.load(os.path.join(args.save, 'flops.pth')) 14 15 with open(os.path.join(args.save, 'dynamic.txt'), 'w') as fout: 16 for p in range(1, 40): 17 print("*********************") 18 # 在這個for循環中生成一個0.05-1.95,以生成40組不同分類器的權重 19 _p = torch.FloatTensor(1).fill_(p * 1.0 / 20) 20 # 通過一個對數生成一個nBlocks維的tensor,就是不同分類器所需要處理的數據比例 21 probs = torch.exp(torch.log(_p) * torch.range(1, args.nBlocks)) 22 probs /= probs.sum() 23 # 利用驗證集去找閾值 24 acc_val, _, T = tester.dynamic_eval_find_threshold( 25 val_pred, val_target, probs, flops) 26 # 利用閾值給測試集安排出口與分類結果 27 acc_test, exp_flops = tester.dynamic_eval_with_threshold( 28 test_pred, test_target, flops, T) 29 print('valid acc: {:.3f}, test acc: {:.3f}, test flops: {:.2f}M'.format(acc_val, acc_test, exp_flops / 1e6)) 30 fout.write('{}\t{}\n'.format(acc_test, exp_flops.item()))
在calc_logit函數中,主要目的就是對兩個數據集合進行置信度的計算,我們都知道對於一般的神經網絡來說,分類器最后的輸出結果都是一個classes維的向量,代表了對於該樣本各個類別的置信度結果,向量中的最大的那個值所對應的下標也就是所對應的分類結果。
下面的代碼中可以看到,假設m個blocks,n個數據樣本,c個類別,首先生成空的m維的列表logits用來裝計算結果。將數據放進模型中輸出output,這個output[0]就代表着第1個分類器輸出的預測結果,以此類推,這里將預測結果再用softmax使數據分布更為明顯,放進logits中。對於這個輸出結果,可以判斷這是一個三階的矩陣,它的size大小為(m,n,c)。
至於為什么是這個size,m代表了m個分類器的輸出結果,所以首先是m,對於logits[0],size為(n,c),n個c維向量所組成的矩陣,代表着每個樣本數據進入模型的預測結果,所以共n行c列
1 def calc_logit(self, dataloader): 2 self.model.eval() 3 n_stage = self.args.nBlocks 4 logits = [[] for _ in range(n_stage)] 5 targets = [] 6 for i, (input, target) in enumerate(dataloader): 7 targets.append(target) 8 with torch.no_grad(): 9 input_var = torch.autograd.Variable(input) 10 # 模型生成每個分類器的預測結果 11 output = self.model(input_var) 12 if not isinstance(output, list): 13 output = [output] 14 # softmax相當於將值映射到0-1直接並且和為1 15 for b in range(n_stage): 16 _t = self.softmax(output[b]) 17 logits[b].append(_t) 18 19 if i % self.args.print_freq == 0: 20 print('Generate Logit: [{0}/{1}]'.format(i, len(dataloader))) 21 for b in range(n_stage): 22 logits[b] = torch.cat(logits[b], dim=0) 23 # logits相當於每個block輸出的結果,首先是nBlocks維, 24 # 因為有多個分類器輸出,logits[0]~logits[nBlocks-1] 25 # 對於每個輸出結果,肯定是輸入的數量num*classes類別置信度向量 然后根據size變成張量 26 size = (n_stage, logits[0].size(0), logits[0].size(1)) 27 ts_logits = torch.Tensor().resize_(size).zero_() 28 for b in range(n_stage): 29 ts_logits[b].copy_(logits[b]) 30 # 將targets也變成張量 31 targets = torch.cat(targets, dim=0) 32 ts_targets = torch.Tensor().resize_(size[1]).copy_(targets) 33 return ts_logits, ts_targets
計算出了置信度也組成了合適的數據結構,下一步就是最重要的核心操作,在不同的分類器權重組合中去找這40種threshold組合。這段代碼中用到很多pytorch的基本函數如max(dim,) sort(dim,) ge() type_as()等,具體用法在代碼注釋中都有寫。現在我來闡述一下這個的思路。
根據這一部分數據樣本的預測結果,首先將每一個樣本分類的置信度結果的最大值和下標都取出來,也就相當於是正常的分類結果。對於輸出的結果為一個m行n列的矩陣,每一行代表每一個分類器,每一列代表每個數據樣本,矩陣的數值表示在該分類器下預測結果最大的那個置信度,同時有一個相同size的矩陣並記錄其下標。對於生成的這個矩陣,在行的維度上進行一次排序,就意味着每一行的置信度向量都是從高到低排列的,每個分類器對於自己所分出來的最有自信的結果放在了最前面,同時也記錄了排序后對應的原下標。在之前已經設定好了每一個分類器的一個權重,根據權重可以分配每個分類器輸出的樣本數量,所以就從高到低對於每一個分類器輸出的結果進行分類,假設第一個分類器可以分出200個數據,那從高到低排在第200的那個置信度的值就是這個分類器的閾值,這里新開了一個n維的list用來記錄每一個樣本有沒有從網絡中輸出,樣本從前一個分類器出去之后自然后續不用再進行考慮,用作標志位,最后一個分類器的閾值設定為無窮小因為所有剩余的樣本都要在這個分類器出來。這樣大致就計算出了每個分類器的閾值。
1 def dynamic_eval_find_threshold(self, logits, targets, p, flops): 2 """ 3 logits: m * n * c 4 m: Stages-nblocks 5 n: Samples 6 c: Classes 7 """ 8 n_stage, n_sample, c = logits.size() 9 print(logits.size()) 10 # dim=2返回的max_preds是每一行的最大值,這個最大值也就是預測出來最大置信度的那個置信度 11 # argmax_greds就是這個最大預測值的下標,代表是第幾個 12 # 所以max_preds的維度是nblocks行,samples列,代表每個分類器出來的每個樣本的最大置信度預測結果 13 # 例如7行10列 就表示7個分類器有10個樣本進去輸出的預測結果置信度,arg代表的就是原來的下標代表第幾個 14 max_preds, argmax_preds = logits.max(dim=2, keepdim=False) 15 16 # 這里對max_preds在行上進行排序,_為排序后的結果,sort_id就是對應原來矩陣中的下標 17 _, sorted_idx = max_preds.sort(dim=1, descending=True) 18 # 樣本個數個 19 filtered = torch.zeros(n_sample) 20 # 用來裝閾值的 21 T = torch.Tensor(n_stage).fill_(1e8) 22 # p是每個分類器的權重 23 # 對一個中間分類器而言,已經設定好了從這個分類其中分出去的數量, 24 # 那就把這個分類器出來的所有結果排序后前n個當作這個分類器可以退出的, 25 # 那么第n個分出去的那個預測結果的置信度就是閾值 26 for k in range(n_stage - 1): 27 acc, count = 0.0, 0 28 # 計划每個分類器按照權重分出去的個數 29 out_n = math.floor(n_sample * p[k]) 30 for i in range(n_sample): 31 # ori_idx表示 32 ori_idx = sorted_idx[k][i] 33 # filter記錄着每個樣本是否已經退出 只有還沒有退出的才能作為計算 起標記作用 34 if filtered[ori_idx] == 0: 35 count += 1 36 # 到了預計的退出數量就記下來那個閾值 37 if count == out_n: 38 T[k] = max_preds[k][ori_idx] 39 break 40 # ge判斷張量內每一個數值大小的函數 type_as為了該表數據類型后才能用ge進行比較 add_為加的操作 41 # ge的比較結果在於在本層分類器中有多少個樣本已經退出,得出來的理想結果應該是[1,1,1,1,1...0,0,0,...] 42 # filter本來為一個sample維的0向量,加上比較的結果后就說明標記好了已經退出去的樣本。 43 44 filtered.add_(max_preds[k].ge(T[k]).type_as(filtered)) 45 46 T[n_stage -1] = -1e8 # accept all of the samples at the last stage 47 48 # 計算正確率 49 acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 50 acc, expected_flops = 0, 0 51 for i in range(n_sample): 52 gold_label = targets[i] 53 for k in range(n_stage): 54 if max_preds[k][i].item() >= T[k]: # force the sample to exit at k 55 if int(gold_label.item()) == int(argmax_preds[k][i].item()): 56 acc += 1 57 acc_rec[k] += 1 58 exp[k] += 1 59 break 60 acc_all = 0 61 # 根據比例計算flops 62 for k in range(n_stage): 63 _t = 1.0 * exp[k] / n_sample 64 expected_flops += _t * flops[k] 65 acc_all += acc_rec[k] 66 67 return acc * 100.0 / n_sample, expected_flops, T
得到了閾值就很方便后續的計算了,同樣對於置信度最大值的矩陣進行排序,從大到小輸入網絡中依次進行判斷,如果置信度大於設定的閾值就輸出,同時計算出每一個分類器所輸出的樣本個數,根據不同分類器所對應的flops的不同計算出實際所用的flops。同時,根據每一個分類器所分對的樣本數量計算出了一個總的正確分類個數,計算出了測試正確率。
1 def dynamic_eval_with_threshold(self, logits, targets, flops, T): 2 # 和上面類似 接下來是一個比較的過程 3 n_stage, n_sample, _ = logits.size() 4 max_preds, argmax_preds = logits.max(dim=2, keepdim=False) # take the max logits as confidence 5 # acc為總的正確個數 acc_rec為每一個分類器的正確個數 6 acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 7 acc, expected_flops = 0, 0 8 for i in range(n_sample): 9 gold_label = targets[i] 10 for k in range(n_stage): 11 if max_preds[k][i].item() >= T[k]: # force to exit at k 12 _g = int(gold_label.item()) 13 _pred = int(argmax_preds[k][i].item()) 14 if _g == _pred: 15 acc += 1 16 acc_rec[k] += 1 17 exp[k] += 1 18 break 19 # 根據每個分類器退出數量計算出flops,flops計算的時候本身就是一個分類器一個flops 20 acc_all, sample_all = 0, 0 21 for k in range(n_stage): 22 _t = exp[k] * 1.0 / n_sample 23 sample_all += exp[k] 24 expected_flops += _t * flops[k] 25 acc_all += acc_rec[k] 26 return acc * 100.0 / n_sample, expected_flops
這里基本上完整敘述了動態推理的過程,到此才真正理解了什么是所謂的動態推理。但其實我在思考問題在於真正實際使用時,是在測試的時候實現動態推理,這里是在訓練集中提前分了一部分數據出來作為驗證集來計算這個threshold,那如果真正使用中是不是應該把這一部分放在訓練部分比較更為合適些,這樣才能實現真正的動態推理測試。
以上均為個人理解,如果還有任何對於本文的疑問,歡迎留言討論