[pytorch] 自定義激活函數swish(三)


[pytorch] 自定義激活函數swish(三)

 

在神經網絡模型中,激活函數多種多樣。大體都是,小於0的部分,進行抑制(即,激活函數輸出為非常小的數),大於0的部分,進行放大(即,激活函數輸出為較大的數)。

         主流的激活函數一般都滿足,

         1. 非線性。信號處理里,信號通過非線性系統后,能產生新頻率的信號。不妨假定,非線性有相似作用。

         2. 可微性。可求導的,在反向傳播中,可以方便使用鏈式求導的。

         3. 單調性。swish 激活函數在小於0的部分是非單調的。

         為了測試不同激活函數,對神經網絡的影響。我們把之前寫的CNN模型中,激活函數抽取出來,獨立寫成一個接口。

         由於pytorch集成了常用的激活函數,對於已經集成好的ReLU等函數。可以使用簡單的

def Act_op():
return nn.ReLU()

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.con_layer1 = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
Act_op()
)
def forward(self, x):
x = self.con_layer1(x)
return x
        對於Swish = x*sigmod(x) 這種pytorch還沒集成的函數,就需要自定義Act_op()。

方法一:使用nn.Function

## 由於 Function 可能需要暫存 input tensor。
## 因此,建議不復用 Function 對象,以避免遇到內存提前釋放的問題。
class Swish_act(torch.autograd.Function):
## save_for_backward can only!!!! save input or output tensors
@staticmethod
def forward(self, input_):
print('swish act op forward')
output = input_ * F.sigmoid(input_)
self.save_for_backward(input_)
return output

@staticmethod
def backward(self, grad_output):
## according to the chain rule(Backpropagation),
## d(loss)/d(x) = d(loss)/d(output) * d(output)/d(x)
## grad_output is the d(loss)/d(output)
## we calculate and save the d(output)/d(x) in forward
input_, = self.saved_tensors
output = input_ * F.sigmoid(input_)
grad_swish = output + F.sigmoid(input_) * (1 - output)
print('swish act op backward')
return grad_output * grad_swish

def Act_op():
return Swish_act()

在使用這種方法寫的時候,遇到了幾個坑。

首先,save_to_backward()只允許保存輸入、輸出的張量。比如,輸入為a(即,forward(self, a)),那么save_to_backward(a)沒問題,save_to_backward(a+1)報錯。

其次,根據pytorch的邏輯,nn.Function是在nn.Module的forward()過程中使用,不能在__init__中使用。

其三,如果在模型中,需要重復調用這個Swish_act()接口。會出現前一次使用的內存被提前釋放掉,使得在反向傳播計算中,需要使用的變量失效。

 

為了解決不能重復調用問題,可以使用nn.Module,創建CNN網絡模型的Module子類。

方法二:

class Act_op(nn.Module):
def __init__(self):
super(Act_op, self).__init__()

def forward(self, x):
x = x * F.sigmoid(x)
return x
     簡單、快捷、方便。還不用自己寫backward。而且,不需要對之前CNN模型里,class Net(nn.Module)做任何改動。
---------------------
作者:cling-L
來源:CSDN
原文:https://blog.csdn.net/lingdexixixi/article/details/79796605
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM