報錯的原因在於Pytorch0.4之后,在BN層后新增加了track_running_stats這個參數。
在調用預訓練參數模型是,官方給定的預訓練模型是在pytorch0.4之前,因此,調用預訓練參數時,需要過濾掉“num_batches_tracked”。
以resnet50為例:
為了加載不同層的權重,采用兩個函數,如下:load_partial_param用於加載layer1, layer2, layer3, layer4的權重權重,load_specific_param用於加載第一層的權重參數。
為了避免“num_batches_tracked”報錯,采用下面的代碼即可,更改部分為紅色字體(方法簡單,但可以滿足要求)。
def load_partial_param(self, state_dict, model_index, model_path): param_dict = torch.load(model_path) for i in state_dict: key = 'layer{}.'.format(model_index)+i if 'tracked' in key[-7:]: continue state_dict[i].copy_(param_dict[key]) del param_dict
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) for i in state_dict: key = param_name + '.' + i if 'num_batches_tracked' in key: continue state_dict[i].copy_(param_dict[key]) del param_dict