只保存參數信息
加載
checkpoint = torch.load(opt.resume)
model.load_state_dict(checkpoint)
保存
torch.save(self.state_dict(),file_path)
這而只保存了參數信息,讀取時也只有參數信息,模型結構需要手動編寫
保存整個模型
保存
torch.save(the_model, PATH)
加載:
the_model = torch.load(PATH)
有時候會看到加載時
model.load_state_dict(checkpoint['state_dic'])
這是因為checkpoint是一個字典,保存的key可以自己定義。
可以保存除參數信息之外的其它信息,如epoch等。
保存
torch.save({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, }, 'checkpoint.tar' )
加載
if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.evaluate, checkpoint['epoch']))
state_dict參考鏈接:
https://www.cnblogs.com/tingtin/p/13544489.html