PyTorch模型讀寫、參數初始化、Finetune


使用了一段時間PyTorch,感覺愛不釋手(0-0),聽說現在已經有C++接口。在應用過程中不可避免需要使用Finetune/參數初始化/模型加載等。

模型保存/加載

1.所有模型參數

訓練過程中,有時候會由於各種原因停止訓練,這時候我們訓練過程中就需要注意將每一輪epoch的模型保存(一般保存最好模型與當前輪模型)。一般使用pytorch里面推薦的保存方法。該方法保存的是模型的參數。

#保存模型到checkpoint.pth.tar torch.save(model.module.state_dict(), checkpoint.pth.tar)

對應的加載模型方法為(這種方法需要先反序列化模型獲取參數字典,因此必須先load模型,再load_state_dict):

mymodel.load_state_dict(torch.load(checkpoint.pth.tar))

有了上面的保存后,現以一個例子說明如何在inference AND/OR resume train使用。

#保存模型的狀態,可以設置一些參數,后續可以使用 state = {'epoch': epoch + 1,#保存的當前輪數 'state_dict': mymodel.state_dict(),#訓練好的參數 'optimizer': optimizer.state_dict(),#優化器參數,為了后續的resume 'best_pred': best_pred#當前最好的精度 ,....,...} #保存模型到checkpoint.pth.tar torch.save(state, checkpoint.pth.tar) #如果是best,則復制過去 if is_best: shutil.copyfile(filename, directory + 'model_best.pth.tar') checkpoint = torch.load('model_best.pth.tar') model.load_state_dict(checkpoint['state_dict'])#模型參數 optimizer.load_state_dict(checkpoint['optimizer'])#優化參數 epoch = checkpoint['epoch']#epoch,可以用於更新學習率等 #有了以上的東西,就可以繼續重新訓練了,也就不需要擔心停止程序重新訓練。 train/eval .... ....

上面是pytorch建議使用的方法,當然還有第二種方法。這種方法靈活性不高,不推薦。

#保存 torch.save(mymodel,checkpoint.pth.tar) #加載 mymodel = torch.load(checkpoint.pth.tar)

2.部分模型參數

在很多時候,我們加載的是已經訓練好的模型,而訓練好的模型可能與我們定義的模型不完全一樣,而我們只想使用一樣的那些層的參數。

有幾種解決方法:

(1)直接在訓練好的模型開始搭建自己的模型,就是先加載訓練好的模型,然后再它基礎上定義自己的模型;

model_ft = models.resnet18(pretrained=use_pretrained) self.conv1 = model_ft.conv1 self.bn = model_ft.bn ... ...

(2) 自己定義好模型,直接加載模型

#第一種方法: mymodelB = TheModelBClass(*args, **kwargs) # strict=False,設置為false,只保留鍵值相同的參數 mymodelB.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) #第二種方法: # 加載模型 model_pretrained = models.resnet18(pretrained=use_pretrained) # mymodel's state_dict, # 如: conv1.weight # conv1.bias mymodelB_dict = mymodelB.state_dict() # 將model_pretrained的建與自定義模型的建進行比較,剔除不同的 pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict} # 更新現有的model_dict mymodelB_dict.update(pretrained_dict) # 加載我們真正需要的state_dict mymodelB.load_state_dict(mymodelB_dict) # 方法2可能更直觀一些

參數初始化

第二個問題是參數初始化問題,在很多代碼里面都會使用到,畢竟不是所有的都是有預訓練參數。這時就需要對不是與預訓練參數進行初始化。pytorch里面的每個Tensor其實是對Variabl的封裝,其包含data、grad等接口,因此可以用這些接口直接賦值。這里也提供了怎樣把其他框架(caffe/tensorflow/mxnet/gluonCV等)訓練好的模型參數直接賦值給pytorch.其實就是對data直接賦值。

pytorch提供了初始化參數的方法:

 def weight_init(m): if isinstance(m,nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0,math.sqrt(2./n)) elif isinstance(m,nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_()

但一般如果沒有很大需求初始化參數,也沒有問題(不確定性能是否有影響的情況下),pytorch內部是有默認初始化參數的。

Fintune

最后就是精調了,我們平時做實驗,至少backbone是用預訓練的模型,將其用作特征提取器,或者在它上面做精調。

用於特征提取的時候,要求特征提取部分參數不進行學習,而pytorch提供了requires_grad參數用於確定是否進去梯度計算,也即是否更新參數。以下以minist為例,用resnet18作特征提取:

#加載預訓練模型 model = torchvision.models.resnet18(pretrained=True) #遍歷每一個參數,將其設置為不更新參數,即不學習 for param in model.parameters(): param.requires_grad = False # 將全連接層改為mnist所需的10類,注意:這樣更改后requires_grad默認為True model.fc = nn.Linear(512, 10) # 優化 optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9) 

用於全局精調時,我們一般對不同的層需要設置不同的學習率,預訓練的層學習率小一點,其他層大一點。這要怎么做呢?

# 加載預訓練模型 model = torchvision.models.resnet18(pretrained=True) model.fc = nn.Linear(512, 10) # 參考:https://blog.csdn.net/u012759136/article/details/65634477 ignored_params = list(map(id, model.fc.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) # 對不同參數設置不同的學習率 params_list = [{'params': base_params, 'lr': 0.001},] params_list.append({'params': model.fc.parameters(), 'lr': 0.01}) optimizer = torch.optim.SGD(params_list, 0.001momentum=args.momentum, weight_decay=args.weight_decay)

最后整理一下目前,pytorch預訓練的基礎模型:

(1)torchvision

torchvision里面已經提供了不同的預訓練模型,一般也夠用了。

pytorch/visiongithub.com圖標

包含了alexnet/densenet各種版本(densenet121/densenet169/densenet201/densenet161)/inception_v3/resnet各種版本(resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152')/SqueezeNet各種版本( 'squeezenet1_0', 'squeezenet1_1')/VGG各種版本( 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19')

(2)其他預訓練好的模型,如,SENet/NASNet等。

Cadene/pretrained-models.pytorchgithub.com

(3)gluonCV轉pytorch的模型,包括,分類網絡,分割網絡等,這里的精度均比其他框架高幾個百分點。

zhanghang1989/gluoncv-torchgithub.com圖標


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM