[PyTorch]論文pytorch復現中遇到的BUG


1. zip argument #1 must support iteration

在多gpu訓練的時候,自動把你的batch_size分成n_gpu份,每個gpu跑一些數據, 最后再合起來。我之所以出現這個bug是因為返回的時候 返回了一個常量。。

2. torch.nn.DataParallel

在使用torch.nn.DataParallel時候,要先把模型放在gpu上,再進行parallel。

3. model.state_dict()

一般在現有的網絡加載預訓練模型通常是找到預訓練模型在現有的model里面的參數,然后model進行更新,遇到一個bug, 發現加載預訓練模型的時候, 效果很差,跟參數沒有更新一樣,找了一大頓bug,最后才發現,之前是單gpu進行的預訓練,現在的模型使用的是多gpu, 打印現在模型的參數你會發現他所有的參數前面都加了一個module. 所以向以前一樣更新,沒有一個參數會被更新,因此寫了一個萬能模型參數加載函數。

pretrained_dict = checkpoint['state_dict']
model_dict = self.model.state_dict()
if checkpoint['config']['n_gpu'] > 1 and self.config['n_gpu'] == 1:
    new_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[7:]
        new_dict[name] = v
    pretrained_dict = new_dict
elif checkpoint['config']['n_gpu'] == 1 and self.config['n_gpu'] > 1:
    new_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = "module."+k
        new_dict[name] = v
    pretrained_dict = new_dict
print("The pretrained model's para is following")
for k, v in pretrained_dict.items():
    print(k)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM