optimizer.state_dict()、optimizer.param_groups


net = t.nn.Linear(2, 3)
optimizer = t.optim.SGD(net.parameters(), lr=0.2)
for key, value in optimizer.state_dict().items():
print(key, value)
for i, param_group in enumerate(optimizer.param_groups):
print(i+1)
print(param_group)

1、optimizer.state_dict()

"""

state {}
param_groups [{'lr': 0.2, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140327302981024, 140327686399752]}]

"""

是一個字典,包括優化器的狀態(state)以及一些超參數信息(param_groups)

2、optimizer.param_groups

"""

1
{'params': [Parameter containing:
tensor([[-0.2604, 0.0777],
[-0.6420, 0.5030],
[-0.3879, -0.5129]], requires_grad=True), Parameter containing:
tensor([ 0.6245, 0.4680, -0.3667], requires_grad=True)], 'lr': 0.2, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}

"""

是param_groups是一個數組,數組內部包含n個字典

總結:state_dict()包括param_groups

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM