[pytorch] 自定義激活函數中的注意事項


如何在pytorch中使用自定義的激活函數?

如果自定義的激活函數是可導的,那么可以直接寫一個python function來定義並調用,因為pytorch的autograd會自動對其求導。

如果自定義的激活函數不是可導的,比如類似於ReLU的分段可導的函數,需要寫一個繼承torch.autograd.Function的類,並自行定義forward和backward的過程

在pytorch中提供了定義新的autograd function的tutorial: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html, tutorial以ReLU為例介紹了在forward, backward中需要自行定義的內容。

 1 import torch
 2 
 3 
 4 class MyReLU(torch.autograd.Function):
 5     """
 6     We can implement our own custom autograd Functions by subclassing
 7     torch.autograd.Function and implementing the forward and backward passes
 8     which operate on Tensors.
 9     """
10 
11     @staticmethod
12     def forward(ctx, input):
13         """
14         In the forward pass we receive a Tensor containing the input and return
15         a Tensor containing the output. ctx is a context object that can be used
16         to stash information for backward computation. You can cache arbitrary
17         objects for use in the backward pass using the ctx.save_for_backward method.
18         """
19         ctx.save_for_backward(input)
20         return input.clamp(min=0)
21 
22     @staticmethod
23     def backward(ctx, grad_output):
24         """
25         In the backward pass we receive a Tensor containing the gradient of the loss
26         with respect to the output, and we need to compute the gradient of the loss
27         with respect to the input.
28         """
29         input, = ctx.saved_tensors
30         grad_input = grad_output.clone()
31         grad_input[input < 0] = 0
32         return grad_input
33 
34 
35 dtype = torch.float
36 device = torch.device("cpu")
37 # device = torch.device("cuda:0") # Uncomment this to run on GPU
38 
39 # N is batch size; D_in is input dimension;
40 # H is hidden dimension; D_out is output dimension.
41 N, D_in, H, D_out = 64, 1000, 100, 10
42 
43 # Create random Tensors to hold input and outputs.
44 x = torch.randn(N, D_in, device=device, dtype=dtype)
45 y = torch.randn(N, D_out, device=device, dtype=dtype)
46 
47 # Create random Tensors for weights.
48 w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
49 w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
50 
51 learning_rate = 1e-6
52 for t in range(500):
53     # To apply our Function, we use Function.apply method. We alias this as 'relu'.
54     relu = MyReLU.apply
55 
56     # Forward pass: compute predicted y using operations; we compute
57     # ReLU using our custom autograd operation.
58     y_pred = relu(x.mm(w1)).mm(w2)
59 
60     # Compute and print loss
61     loss = (y_pred - y).pow(2).sum()
62     print(t, loss.item())
63 
64     # Use autograd to compute the backward pass.
65     loss.backward()
66 
67     # Update weights using gradient descent
68     with torch.no_grad():
69         w1 -= learning_rate * w1.grad
70         w2 -= learning_rate * w2.grad
71 
72         # Manually zero the gradients after updating weights
73         w1.grad.zero_()
74         w2.grad.zero_()

但是如果定義ReLU函數時,沒有使用以上正確的方法,而是直接自定義的函數,會出現什么問題呢?

這里對比了使用以上MyReLU和自定義函數:no_back的實驗結果。

1 def no_back(x):
2     return x * (x > 0).float()

代碼:

N, D_in, H, D_out = 2, 3, 4, 5

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
origin_w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
origin_w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-3

def myReLU(func, x, y, origin_w1, origin_w2, learning_rate,N = 2, D_in = 3, H = 4, D_out = 5):
    w1 = deepcopy(origin_w1)
    w2 = deepcopy(origin_w2)
    for t in range(5):
        # Forward pass: compute predicted y using operations; we compute
        # ReLU using our custom autograd operation.
        y_pred = func(x.mm(w1)).mm(w2)

        # Compute and print loss
        loss = (y_pred - y).pow(2).sum()
        print("------", t, loss.item(), "------------")

        # Use autograd to compute the backward pass.
        loss.backward()

        # Update weights using gradient descent
        with torch.no_grad():
            print('w1 = ')
            print(w1)
            print('---------------------')
            print("x.mm(w1) = ")
            print(x.mm(w1))
            print('---------------------')
            print('func(x.mm(w1))')
            print(func(x.mm(w1)))
            print('---------------------')
            print("w1.grad:", w1.grad)
            # print("w2.grad:",w2.grad)
            print('---------------------')

            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights
            w1.grad.zero_()
            w2.grad.zero_()
            print('========================')
            print()


myReLU(func = MyReLU.apply, x = x, y = y, origin_w1 = origin_w1, origin_w2 = origin_w2, learning_rate = learning_rate, N = 2, D_in = 3, H = 4, D_out = 5)
print('============')
print('============')
print('============')
myReLU(func = no_back, x = x, y = y, origin_w1 = origin_w1, origin_w2 = origin_w2, learning_rate = learning_rate, N = 2, D_in = 3, H = 4, D_out = 5)

對於使用了MyReLU.apply的實驗結果為:

 1 ------ 0 20.18220329284668 ------------
 2 w1 = 
 3 tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
 4         [ 0.7425, -0.6309,  0.3268, -1.5072],
 5         [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
 6 ---------------------
 7 x.mm(w1) = 
 8 tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
 9         [-0.7692, -1.8556, -0.7085, -0.9849]])
10 ---------------------
11 func(x.mm(w1))
12 tensor([[0.0000, 1.0135, 0.0000, 1.8834],
13         [0.0000, 0.0000, 0.0000, 0.0000]])
14 ---------------------
15 w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
16         [  0.0000,  -4.4962,   0.0000, -16.9378],
17         [  0.0000,  -0.2401,   0.0000,  -0.9043]])
18 ---------------------
19 ========================
20 
21 ------ 1 19.546737670898438 ------------
22 w1 = 
23 tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
24         [ 0.7425, -0.6265,  0.3268, -1.4903],
25         [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
26 ---------------------
27 x.mm(w1) = 
28 tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
29         [-0.7692, -1.8574, -0.7085, -0.9915]])
30 ---------------------
31 func(x.mm(w1))
32 tensor([[0.0000, 1.0078, 0.0000, 1.8618],
33         [0.0000, 0.0000, 0.0000, 0.0000]])
34 ---------------------
35 w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
36         [  0.0000,  -4.3446,   0.0000, -16.4493],
37         [  0.0000,  -0.2320,   0.0000,  -0.8782]])
38 ---------------------
39 ========================
40 
41 ------ 2 18.94647789001465 ------------
42 w1 = 
43 tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
44         [ 0.7425, -0.6221,  0.3268, -1.4738],
45         [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
46 ---------------------
47 x.mm(w1) = 
48 tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
49         [-0.7692, -1.8591, -0.7085, -0.9978]])
50 ---------------------
51 func(x.mm(w1))
52 tensor([[0.0000, 1.0023, 0.0000, 1.8409],
53         [0.0000, 0.0000, 0.0000, 0.0000]])
54 ---------------------
55 w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
56         [  0.0000,  -4.2009,   0.0000, -15.9835],
57         [  0.0000,  -0.2243,   0.0000,  -0.8534]])
58 ---------------------
59 ========================
60 
61 ------ 3 18.378826141357422 ------------
62 w1 = 
63 tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
64         [ 0.7425, -0.6179,  0.3268, -1.4578],
65         [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
66 ---------------------
67 x.mm(w1) = 
68 tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
69         [-0.7692, -1.8607, -0.7085, -1.0040]])
70 ---------------------
71 func(x.mm(w1))
72 tensor([[0.0000, 0.9969, 0.0000, 1.8206],
73         [0.0000, 0.0000, 0.0000, 0.0000]])
74 ---------------------
75 w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
76         [  0.0000,  -4.0644,   0.0000, -15.5391],
77         [  0.0000,  -0.2170,   0.0000,  -0.8296]])
78 ---------------------
79 ========================
80 
81 ------ 4 17.841421127319336 ------------
82 w1 = 
83 tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
84         [ 0.7425, -0.6138,  0.3268, -1.4423],
85         [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
86 ---------------------
87 x.mm(w1) = 
88 tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
89         [-0.7692, -1.8623, -0.7085, -1.0100]])
90 ---------------------
91 func(x.mm(w1))
92 tensor([[0.0000, 0.9918, 0.0000, 1.8008],
93         [0.0000, 0.0000, 0.0000, 0.0000]])
94 ---------------------
95 w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
96         [  0.0000,  -3.9346,   0.0000, -15.1145],
97         [  0.0000,  -0.2101,   0.0000,  -0.8070]])
98 ---------------------
99 ========================
View Code

對於使用了no_back的實驗結果為:

 1 ------ 0 20.18220329284668 ------------
 2 w1 = 
 3 tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
 4         [ 0.7425, -0.6309,  0.3268, -1.5072],
 5         [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
 6 ---------------------
 7 x.mm(w1) = 
 8 tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
 9         [-0.7692, -1.8556, -0.7085, -0.9849]])
10 ---------------------
11 func(x.mm(w1))
12 tensor([[-0.0000, 1.0135, -0.0000, 1.8834],
13         [-0.0000, -0.0000, -0.0000, -0.0000]])
14 ---------------------
15 w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
16         [  0.0000,  -4.4962,   0.0000, -16.9378],
17         [  0.0000,  -0.2401,   0.0000,  -0.9043]])
18 ---------------------
19 ========================
20 
21 ------ 1 19.546737670898438 ------------
22 w1 = 
23 tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
24         [ 0.7425, -0.6265,  0.3268, -1.4903],
25         [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
26 ---------------------
27 x.mm(w1) = 
28 tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
29         [-0.7692, -1.8574, -0.7085, -0.9915]])
30 ---------------------
31 func(x.mm(w1))
32 tensor([[-0.0000, 1.0078, -0.0000, 1.8618],
33         [-0.0000, -0.0000, -0.0000, -0.0000]])
34 ---------------------
35 w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
36         [  0.0000,  -4.3446,   0.0000, -16.4493],
37         [  0.0000,  -0.2320,   0.0000,  -0.8782]])
38 ---------------------
39 ========================
40 
41 ------ 2 18.94647789001465 ------------
42 w1 = 
43 tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
44         [ 0.7425, -0.6221,  0.3268, -1.4738],
45         [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
46 ---------------------
47 x.mm(w1) = 
48 tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
49         [-0.7692, -1.8591, -0.7085, -0.9978]])
50 ---------------------
51 func(x.mm(w1))
52 tensor([[-0.0000, 1.0023, -0.0000, 1.8409],
53         [-0.0000, -0.0000, -0.0000, -0.0000]])
54 ---------------------
55 w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
56         [  0.0000,  -4.2009,   0.0000, -15.9835],
57         [  0.0000,  -0.2243,   0.0000,  -0.8534]])
58 ---------------------
59 ========================
60 
61 ------ 3 18.378826141357422 ------------
62 w1 = 
63 tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
64         [ 0.7425, -0.6179,  0.3268, -1.4578],
65         [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
66 ---------------------
67 x.mm(w1) = 
68 tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
69         [-0.7692, -1.8607, -0.7085, -1.0040]])
70 ---------------------
71 func(x.mm(w1))
72 tensor([[-0.0000, 0.9969, -0.0000, 1.8206],
73         [-0.0000, -0.0000, -0.0000, -0.0000]])
74 ---------------------
75 w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
76         [  0.0000,  -4.0644,   0.0000, -15.5391],
77         [  0.0000,  -0.2170,   0.0000,  -0.8296]])
78 ---------------------
79 ========================
80 
81 ------ 4 17.841421127319336 ------------
82 w1 = 
83 tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
84         [ 0.7425, -0.6138,  0.3268, -1.4423],
85         [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
86 ---------------------
87 x.mm(w1) = 
88 tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
89         [-0.7692, -1.8623, -0.7085, -1.0100]])
90 ---------------------
91 func(x.mm(w1))
92 tensor([[-0.0000, 0.9918, -0.0000, 1.8008],
93         [-0.0000, -0.0000, -0.0000, -0.0000]])
94 ---------------------
95 w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
96         [  0.0000,  -3.9346,   0.0000, -15.1145],
97         [  0.0000,  -0.2101,   0.0000,  -0.8070]])
98 ---------------------
99 ========================
View Code

對比發現,二者在梯度大小及更新的數值、loss大小等都是數值相等的,這是否說明對於不可導函數,直接定義函數也可以取得和正確定義前向后向過程相同的結果呢?

應當注意到一個問題,那就是在MyReLU.apply的實驗結果中,出現數值為0的地方,顯示為0.0000,而在no_back的實驗結果中,出現數值為0的地方,顯示為-0.0000;

0.0000與-0.0000有什么區別呢?

參考stack overflow中的解答:https://stackoverflow.com/questions/4083401/negative-zero-in-python

和wikipedia中對於signed zero的介紹:https://en.wikipedia.org/wiki/Signed_zero

在python中二者是顯然不同的對象,但是在數值比較時,二者的值顯示為相等。

-0.0 == +0.0 == 0

在Python 中使它們數值相等的設定,是在盡量避免為code引入bug.

>>> a = 3.4
>>> b =4.4
>>> c = -0.0
>>> d = +0.0
>>> a*c
-0.0
>>> b*d
0.0
>>> a*c == b*d
True
>>> 

雖然看起來,它們在使用中並沒有什么區別,但是在計算機內部對它們的編碼表示並不相同。

在對於整數的1+7位元的符號數值表示法中,負零是用二進制代碼10000000表示的。在8位元二進制反碼中,負零是用二進制代碼11111111表示,但補碼表示法則沒有負零的概念。在IEEE 754二進制浮點數算術標准中,指數和尾數為零、符號位元為一的數就是負零。

IBM的普通十進制算數編碼規范中,運用十進制來表示浮點數。這里負零被表示為指數為編碼內任意合法數值、所有系數均為零、符號位元為一的數。

 ~(wikipedia)

在數值分析中,也常將-0看做從負數區間無限趨近於0的值,將+0看做從正數區間無限趨近於0的值,二者在數值上近似相等,但在某些操作中卻可能產生不同的結果。

比如 divmod,會沿用數值的sign:

>>> divmod(-0.0,100)
(-0.0, 0.0)
>>> divmod(+0.0,100)
(0.0, 0.0)

比如 atan2, (介紹詳見https://en.wikipedia.org/wiki/Atan2)

 

atan2(+0, +0) = +0;  

atan2(+0, −0) = +π;  ( 當y是位於y軸正半軸,無限趨近於0的值;x是位於x軸負半軸,無限趨近於0的值,=> 可以看做是在第二象限中位於x軸負半軸的一點 => $\theta夾角為$\pi$)

atan2(−0, +0) = −0;  ( 可以看做是在第四象限中位於x軸正半軸的一點 => $\theta夾角為-0)

atan2(−0, −0) = −π.

用代碼驗證:

>>> math.atan2(0.0, 0.0) == math.atan2(-0.0, 0.0)
True 
>>> math.atan2(0.0, -0.0) == math.atan2(-0.0, -0.0)
False

所以,盡管在上面自定義激活函數時,將不可導函數強行加入到pytorch的autograd中運算,數值結果相同;但是注意到-0.0000的出現是程序有bug的提示,嚴謹考慮仍需要規范定義,如MyReLU。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM