在深度學習中參數的初始化十分重要,良好的初始化能讓模型更快收斂,並達到更高水平,而糟糕的初始化則可能使得模型迅速癱瘓。PyTorch中nn.Module的模塊參數都采取了較為合理的初始化策略,因此一般不用我們考慮,當然我們也可以用自定義初始化去代替系統的默認初始化。而當我們在使用Parameter時,自定義初始化則尤為重要,因t.Tensor()返回的是內存中的隨機數,很可能會有極大值,這在實際訓練網絡中會造成溢出或者梯度消失。PyTorch中nn.init
模塊就是專門為初始化而設計,如果某種初始化策略nn.init
不提供,用戶也可以自己直接初始化。
# 利用nn.init初始化
from torch.nn import init
linear = nn.Linear(3, 4)
t.manual_seed(1)
# 等價於 linear.weight.data.normal_(0, std)
init.xavier_normal_(linear.weight)
# 直接初始化
import math
t.manual_seed(1)
# xavier初始化的計算公式
std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0,std)
# 對模型的所有參數進行初始化
for name, params in net.named_parameters():
if name.find('linear') != -1:
# init linear
params[0] # weight
params[1] # bias
elif name.find('conv') != -1:
pass
elif name.find('norm') != -1:
pass
補充
xavier初始化
torch.nn.init.xavier_uniform(tensor, gain=1)
對於輸入的tensor或者變量,通過論文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化數據。
初始化服從均勻分布U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,該初始化方法也稱Glorot initialisation。
參數:
tensor:n維的 torch.Tensor 或者 autograd.Variable類型的數據
a:可選擇的縮放參數
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))
torch.nn.init.xavier_normal(tensor, gain=1)
對於輸入的tensor或者變量,通過論文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化數據。初始化服從高斯分布N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),該初始化方法也稱Glorot initialisation。
參數:
tensor:n維的 torch.Tensor 或者 autograd.Variable類型的數據
a:可選擇的縮放參數
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)
另外在torch.Tensor下還定義了一些in-place的函數:
-
torch.Tensor.bernoulli_()
- in-place version oftorch.bernoulli()
,伯努利分布 -
torch.Tensor.cauchy_()
- numbers drawn from the Cauchy distribution,柯西分布 -
torch.Tensor.exponential_()
- numbers drawn from the exponential distribution,指數分布 -
torch.Tensor.geometric_()
- elements drawn from the geometric distribution,幾何分布 -
torch.Tensor.log_normal_()
- samples from the log-normal distribution,對數正太分布 -
torch.Tensor.normal_()
- in-place version oftorch.normal()
,正太分布 -
torch.Tensor.random_()
- numbers sampled from the discrete uniform distribution,均勻分布 -
torch.Tensor.uniform_()
- numbers sampled from the continuous uniform distribution,連續均勻分布每個的參數不同,像均勻分布等有均值和方差。