pytorch的模型和參數是分開的,可以分別保存或加載模型和參數。
pytorch有兩種模型保存方式:
一、保存整個神經網絡的的結構信息和模型參數信息,save的對象是網絡net
二、只保存神經網絡的訓練模型參數,save的對象是net.state_dict()
對應兩種保存模型的方式,pytorch也有兩種加載模型的方式。對應第一種保存方式,加載模型時通過torch.load('.pth')直接初始化新的神經網絡對象;對應第二種保存方式,需要首先導入對應的網絡,再通過net.load_state_dict(torch.load('.pth'))完成模型參數的加載。
在網絡比較大的時候,第一種方法會花費較多的時間。
1. 直接加載模型和參數 加載別人訓練好的模型: # 保存和加載整個模型 torch.save(model_object, 'resnet.pth') model = torch.load('resnet.pth') 2. 分別加載網絡的結構和參數 # 將my_resnet模型儲存為my_resnet.pth torch.save(my_resnet.state_dict(), "my_resnet.pth") # 加載resnet,模型存放在my_resnet.pth my_resnet.load_state_dict(torch.load("my_resnet.pth")) 其中my_resnet是my_resnet.pth對應的網絡結構。 3. pytorch預訓練模型 1)加載預訓練模型和參數 resnet18 = models.resnet18(pretrained=True) 這里是直接調用pytorch中的常用模型 # PyTorch中的torchvision里有很多常用的模型,可以直接調用: import torchvision.models as models resnet101 = models.resnet18() alexnet = models.alexnet() squeezenet = models.squeezenet1_0() densenet = models.densenet_161() 2)只加載模型,不加載預訓練參數 # 導入模型結構 resnet18 = models.resnet18(pretrained=False) # 加載預先下載好的預訓練參數到resnet18 resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth')) 3)加載部分預訓練模型 resnet152 = models.resnet152(pretrained=True) pretrained_dict = resnet152.state_dict() """加載torchvision中的預訓練模型和參數后通過state_dict()方法提取參數 也可以直接從官方model_zoo下載: pretrained_dict = model_zoo.load_url(model_urls['resnet152'])""" model_dict = model.state_dict() # 將pretrained_dict里不屬於model_dict的鍵剔除掉 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新現有的model_dict model_dict.update(pretrained_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict)