pytorch的模型和參數是分開的,可以分別保存或加載模型和參數。
1、直接保存模型
# 保存模型
torch.save(model, 'model.pth')
# 加載模型
model = torch.load('model.pth')
2、分別加載模型的結構和參數
# 保存模型參數
torch.save(model.state_dict(), 'model.pth')
# 加載模型參數
model.load_state_dict(torch.load('model.pth')
CPU模型加載GPU參數
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
通過DataParalle使用多GPU
model=DataParalle(model)
#保存參數
torch.save(model.module.state_dict(), 'model.pth')
自己習慣用的代碼段
# 判斷gpu是否可用
use_cuda = torch.cuda.is_available()
# 是否使用多gpu
use_multi_gpu = True
# 默認加載的cpu的參數
model.load_state_dict(torch.load('model.pth')
if use_cuda:
model = model.cuda()
if use_multi_gpu:
model = DataParalle(model)
# 保存模型參數(一般保存cpu的參數比較好)
if use_multi_gpu:
torch.save(model.cpu().module.state_dict(), 'model.pth')
else:
torch.save(model.cpu().state_dict(), 'model.pth')
3、pytorch預訓練模型
加載預訓練模型和參數
resnet18 = models.resnet(pretrained=True)
只加載模型,不加載預訓練參數
# 加載模型
resnet18 = models.resnet18(pretrained=False)
# 加載預先下載好的預訓練模型參數
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
加載部分預訓練模型
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加載torchvision中的預訓練模型和參數后通過state_dict()方法提取參數 也可以直接從官方model_zoo下載: pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 將pretrained_dict里不屬於model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
model.load_state_dict(model_dict)
【參考資料】
1、PyTorch學習:加載模型和參數
2、PyTorch使用cpu調用gpu訓練的模型
</div>