import math
from torch.autograd import Variable
import torch
import torch.nn as nn
import warnings

warnings.filterwarnings("ignore")


def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor, printing intermediates.

    For a 2-D tensor, ``fan_in = size(1)`` and ``fan_out = size(0)``.
    For higher-rank tensors, dims 1 and 0 are the input/output feature-map
    counts, and the product of the remaining dims (the receptive field size)
    multiplies both.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("***********_calculate_fan_in_and_fan_out****************")
    dimensions = tensor.dim()
    print("dimensions", dimensions)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        print("fan_in", fan_in)
        print("fan_out", fan_out)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        print("num_input_fmaps", num_input_fmaps)
        print("num_output_fmaps", num_output_fmaps)
        # Product of the trailing dims, computed from the shape. The previous
        # tensor[0][0].numel() raised IndexError whenever dim 0 or dim 1 was 0.
        receptive_field_size = 1
        for s in tensor.shape[2:]:
            receptive_field_size *= s
        print("receptive_field_size", receptive_field_size)
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def xavier_uniform(tensor, gain=1):
    """Fill ``tensor`` in place from U(-a, a) (Glorot/Xavier uniform).

    ``std = gain * sqrt(2 / (fan_in + fan_out))`` and ``a = sqrt(3) * std``.
    Intermediate values are printed.

    Args:
        tensor: tensor with at least 2 dimensions; overwritten in place.
        gain: multiplier applied to the standard deviation.

    Returns:
        ``tensor`` itself, after filling.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("****************xavier_uniform*****************")
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    print("a", a)
    # The fill must bypass autograd. Otherwise an in-place op on a leaf that
    # requires grad (e.g. an nn.Parameter) raises RuntimeError.
    with torch.no_grad():
        return tensor.uniform_(-a, a)


def xavier_normal(tensor, gain=1):
    """Fill ``tensor`` in place from N(0, std^2) (Glorot/Xavier normal).

    ``std = gain * sqrt(2 / (fan_in + fan_out))``. Intermediate values are
    printed.

    Args:
        tensor: tensor with at least 2 dimensions; overwritten in place.
        gain: multiplier applied to the standard deviation.

    Returns:
        ``tensor`` itself, after filling.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("****************xavier_normal*****************")
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    # Bypass autograd so leaves that require grad can be filled in place.
    with torch.no_grad():
        return tensor.normal_(0, std)


# torch.empty, like the legacy torch.Tensor(3, 5), does not initialize memory.
# The first print of ``w`` therefore shows arbitrary values (possibly nan/inf).
w = torch.empty(3, 5)
# Demo: show the uninitialized tensor, then fill it in place with each
# initializer and show it again. The return values get their own names.
# Rebinding ``xavier_uniform`` / ``xavier_normal`` would shadow the
# functions and make them uncallable for the rest of the module.
print("w", w)
w_uniform = xavier_uniform(tensor=w, gain=1)
print("xavier_uniform", w_uniform)
print("w", w)
w_normal = xavier_normal(tensor=w, gain=1)
print("xavier_normal", w_normal)
print("w", w)
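'''
Sample output from an earlier run of this script; values differ on every run.
NOTE: that run used gain = sqrt(2) (the ReLU gain, printed as "relu_gain"),
not the gain=1 passed above. With gain=1, fan_in=5 and fan_out=3 give
std = sqrt(2 / 8) = 0.5 and a = sqrt(3) * 0.5 ~= 0.866. The first "w" print
shows uninitialized memory, which is why it contains arbitrary values and nan.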
w tensor([[6.5103e-38, 0.0000e+00, 5.7453e-44, 0.0000e+00, nan],
[0.0000e+00, 1.3733e-14, 6.4076e+07, 2.0706e-19, 7.3909e+22],
[2.4176e-12, 1.1625e+33, 8.9605e-01, 1.1632e+33, 5.6003e-02]])
relu_gain 1.4142135623730951
****************xavier_uniform*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.7071067811865476
a 1.2247448713915892
xavier_uniform tensor([[ 0.0172, 1.0726, -0.5239, -0.2902, -0.5868],
[ 0.7199, 0.5818, 0.6772, -0.2686, 0.5099],
[ 1.1365, -0.8935, 0.0412, -0.4518, -0.9012]])
w tensor([[ 0.0172, 1.0726, -0.5239, -0.2902, -0.5868],
[ 0.7199, 0.5818, 0.6772, -0.2686, 0.5099],
[ 1.1365, -0.8935, 0.0412, -0.4518, -0.9012]])
****************xavier_normal*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.7071067811865476
xavier_normal tensor([[ 1.8155e+00, -1.5939e+00, 6.1080e-02, 1.3572e-01, -2.4904e-02],
[ 1.6301e-01, -7.8886e-01, 4.7981e-01, 2.5004e-02, 3.0120e-01],
[ 6.4991e-01, 2.0138e-01, -1.7021e-03, 1.2900e-02, 3.7923e-01]])
w tensor([[ 1.8155e+00, -1.5939e+00, 6.1080e-02, 1.3572e-01, -2.4904e-02],
[ 1.6301e-01, -7.8886e-01, 4.7981e-01, 2.5004e-02, 3.0120e-01],
[ 6.4991e-01, 2.0138e-01, -1.7021e-03, 1.2900e-02, 3.7923e-01]])
'''