A Source-Code Walkthrough of PyTorch's state_dict and load_state_dict


One common way to save and load a model in PyTorch is as follows:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout/batch-norm layers into evaluation mode

model.state_dict() actually returns an OrderedDict that stores the names of the network's parameters together with their values. Let's look at how this is implemented in the source code.
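Since the saved object is just this OrderedDict, the checkpoint can be inspected before it is loaded. A minimal sketch, reusing the PATH placeholder from the snippet above:

import torch

sd = torch.load(PATH)  # gives back the OrderedDict itself
for name, tensor in sd.items():
	print(name, tuple(tensor.shape))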

state_dict

# torch.nn.modules.module.py
class Module(object):
	def state_dict(self, destination=None, prefix='', keep_vars=False):
		if destination is None:
			destination = OrderedDict()
			destination._metadata = OrderedDict()
		destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
		# 1. parameters registered on this module
		for name, param in self._parameters.items():
			if param is not None:
				destination[prefix + name] = param if keep_vars else param.data
		# 2. buffers registered on this module
		for name, buf in self._buffers.items():
			if buf is not None:
				destination[prefix + name] = buf if keep_vars else buf.data
		# 3. child modules, visited recursively with a '.'-extended prefix
		for name, module in self._modules.items():
			if module is not None:
				module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
		# 4. user-registered hooks that may post-process the result
		for hook in self._state_dict_hooks.values():
			hook_result = hook(self, destination, prefix, local_metadata)
			if hook_result is not None:
				destination = hook_result
		return destination

As we can see, state_dict iterates over four collections: _parameters, _buffers, _modules, and _state_dict_hooks. The differences between the first three were covered in an earlier article; the last one holds hooks that run while the state_dict is being built, and since it is usually empty we can ignore it. One further point worth noting: submodules are traversed recursively, and the names at each level are joined with a '.', which is what later allows load_state_dict to match parameters by key.

import torch
import torch.nn as nn

class MyModel(nn.Module):
	def __init__(self):
		super(MyModel, self).__init__()
		self.my_tensor = torch.randn(1) # plain tensor attribute, NOT registered
		self.register_buffer('my_buffer', torch.randn(1)) # registered as a buffer
		self.my_param = nn.Parameter(torch.randn(1)) # registered as a parameter
		self.fc = nn.Linear(2,2,bias=False)
		self.conv = nn.Conv2d(2,1,1)
		self.fc2 = nn.Linear(2,2,bias=False)
		self.f3 = self.fc # alias: shares the same weight as self.fc
	def forward(self, x):
		return x

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([-0.3052])), ('my_buffer', tensor([0.5583])), ('fc.weight', tensor([[ 0.6322, -0.0255],
        [-0.4747, -0.0530]])), ('conv.weight', tensor([[[[ 0.3346]],

         [[-0.2962]]]])), ('conv.bias', tensor([0.5205])), ('fc2.weight', tensor([[-0.4949,  0.2815],
        [ 0.3006,  0.0768]])), ('f3.weight', tensor([[ 0.6322, -0.0255],
        [-0.4747, -0.0530]]))])

The output indeed contains the three kinds of registered entries: the parameter my_param, the buffer my_buffer, and the parameters of the submodules. Note that f3.weight repeats fc.weight because f3 aliases fc, while my_tensor, being a plain attribute, does not appear at all.
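The dot-joined naming produced by the recursion becomes even clearer with nested modules. A minimal sketch (Inner and Outer are hypothetical names chosen for this illustration):

import torch
import torch.nn as nn

class Inner(nn.Module):
	def __init__(self):
		super(Inner, self).__init__()
		self.fc = nn.Linear(2, 2, bias=False)

class Outer(nn.Module):
	def __init__(self):
		super(Outer, self).__init__()
		self.block = Inner()

print(list(Outer().state_dict().keys()))
# ['block.fc.weight'] -- the prefix grows by one '.'-separated level per recursion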

load_state_dict

The code below can be read in two parts:

  1. load(self)

This function recursively restores the model's parameters; the source of the _load_from_state_dict it calls is attached at the end of this post.

First, be clear that the state_dict argument holds the model parameters you saved earlier, while local_state inside _load_from_state_dict describes the structure of the model defined in your current code.

The job of _load_from_state_dict is, put simply, this: suppose we want to restore a parameter named conv.weight. The recursion first checks whether conv exists in state_dict and in local_state; a key present in state_dict with no counterpart in the local model is added to unexpected_keys. Otherwise the recursion goes on to check conv.weight itself, and if both sides have it, param.copy_(input_param) is executed, which completes the copy of conv.weight.

  2. if strict:

This part checks whether the copying above produced any unexpected_keys or missing_keys; if so, an error is raised and execution stops. With strict=False, these key mismatches are simply ignored.
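A small experiment makes the strict behaviour concrete. This is a sketch using a hand-built checkpoint dict on a plain nn.Linear, not code from the original post:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)                # expects keys 'weight' and 'bias'
ckpt = {'weight': torch.randn(2, 2),   # shape matches the model
        'extra': torch.randn(1)}       # has no counterpart in the model

try:
	model.load_state_dict(ckpt)        # strict=True is the default
except RuntimeError as e:
	print(e)  # reports Unexpected key(s) "extra" and Missing key(s) "bias"

model.load_state_dict(ckpt, strict=False)  # key mismatches are ignored

With that in mind, here is the source of load_state_dict: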

def load_state_dict(self, state_dict, strict=True):
	missing_keys = []
	unexpected_keys = []
	error_msgs = []

	# copy state_dict so _load_from_state_dict can modify it
	metadata = getattr(state_dict, '_metadata', None)
	state_dict = state_dict.copy()
	if metadata is not None:
		state_dict._metadata = metadata

	# recursively restore this module, then every child, extending the
	# '.'-joined prefix exactly as state_dict did when saving
	def load(module, prefix=''):
		local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
		module._load_from_state_dict(
			state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
		for name, child in module._modules.items():
			if child is not None:
				load(child, prefix + name + '.')

	load(self)

	if strict:
		error_msg = ''
		if len(unexpected_keys) > 0:
			error_msgs.insert(
				0, 'Unexpected key(s) in state_dict: {}. '.format(
					', '.join('"{}"'.format(k) for k in unexpected_keys)))
		if len(missing_keys) > 0:
			error_msgs.insert(
				0, 'Missing key(s) in state_dict: {}. '.format(
					', '.join('"{}"'.format(k) for k in missing_keys)))

	if len(error_msgs) > 0:
		raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
						   self.__class__.__name__, "\n\t".join(error_msgs)))

_load_from_state_dict

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
						  missing_keys, unexpected_keys, error_msgs):
	for hook in self._load_state_dict_pre_hooks.values():
		hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

	local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
	local_state = {k: v.data for k, v in local_name_params if v is not None}

	for name, param in local_state.items():
		key = prefix + name
		if key in state_dict:
			input_param = state_dict[key]

			# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
			if len(param.shape) == 0 and len(input_param.shape) == 1:
				input_param = input_param[0]

			if input_param.shape != param.shape:
				# local shape should match the one in checkpoint
				error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
								  'the shape in current model is {}.'
								  .format(key, input_param.shape, param.shape))
				continue

			if isinstance(input_param, Parameter):
				# backwards compatibility for serialized parameters
				input_param = input_param.data
			try:
				param.copy_(input_param)
			except Exception:
				error_msgs.append('While copying the parameter named "{}", '
								  'whose dimensions in the model are {} and '
								  'whose dimensions in the checkpoint are {}.'
								  .format(key, param.size(), input_param.size()))
		elif strict:
			missing_keys.append(key)

	if strict:
		for key, input_param in state_dict.items():
			if key.startswith(prefix):
				input_name = key[len(prefix):]
				input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
				if input_name not in self._modules and input_name not in local_state:
					unexpected_keys.append(key)
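
Two properties of the copy logic above are easy to verify by hand: param.copy_ writes into the existing tensor in place, so the Parameter object (and anything holding a reference to it, such as an optimizer) survives the load; and the shape check fills error_msgs regardless of strict, so a size mismatch raises even when strict=False. A minimal sketch:

import torch
import torch.nn as nn

fc = nn.Linear(2, 2, bias=False)
w_before = fc.weight
fc.load_state_dict({'weight': torch.ones(2, 2)})
assert fc.weight is w_before  # same Parameter object, values overwritten in place

try:
	fc.load_state_dict({'weight': torch.randn(3, 3)}, strict=False)
except RuntimeError as e:
	print(e)  # size mismatch for weight: checkpoint torch.Size([3, 3]) vs model torch.Size([2, 2])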



WeChat official account: AutoML機器學習
Original post by MARSGGBO
For collaboration or academic discussion, feel free to get in touch. Email: marsggbo@foxmail.com




2019-12-20 21:55:21



