[pytorch] 自定義激活函數swish(三)
在神經網絡模型中,激活函數多種多樣。大體都是,小於0的部分,進行抑制(即,激活函數輸出為非常小的數),大於0的部分,進行放大(即,激活函數輸出為較大的數)。
主流的激活函數一般都滿足,
1. 非線性。信號處理里,信號通過非線性系統后,能產生新頻率的信號。不妨假定,非線性有相似作用。
2. 可微性。可求導的,在反向傳播中,可以方便使用鏈式求導的。
3. 單調性。swish 激活函數在小於0的部分是非單調的。
為了測試不同激活函數,對神經網絡的影響。我們把之前寫的CNN模型中,激活函數抽取出來,獨立寫成一個接口。
由於pytorch集成了常用的激活函數,對於已經集成好的ReLU等函數。可以使用簡單的
def Act_op():
return nn.ReLU()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.con_layer1 = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
Act_op()
)
def forward(self, x):
x = self.con_layer1(x)
return x
對於Swish = x*sigmod(x) 這種pytorch還沒集成的函數,就需要自定義Act_op()。
方法一:使用nn.Function
## 由於 Function 可能需要暫存 input tensor。
## 因此,建議不復用 Function 對象,以避免遇到內存提前釋放的問題。
class Swish_act(torch.autograd.Function):
## save_for_backward can only!!!! save input or output tensors
@staticmethod
def forward(self, input_):
print('swish act op forward')
output = input_ * F.sigmoid(input_)
self.save_for_backward(input_)
return output
@staticmethod
def backward(self, grad_output):
## according to the chain rule(Backpropagation),
## d(loss)/d(x) = d(loss)/d(output) * d(output)/d(x)
## grad_output is the d(loss)/d(output)
## we calculate and save the d(output)/d(x) in forward
input_, = self.saved_tensors
output = input_ * F.sigmoid(input_)
grad_swish = output + F.sigmoid(input_) * (1 - output)
print('swish act op backward')
return grad_output * grad_swish
def Act_op():
return Swish_act()
在使用這種方法寫的時候,遇到了幾個坑。
首先,save_to_backward()只允許保存輸入、輸出的張量。比如,輸入為a(即,forward(self, a)),那么save_to_backward(a)沒問題,save_to_backward(a+1)報錯。
其次,根據pytorch的邏輯,nn.Function是在nn.Module的forward()過程中使用,不能在__init__中使用。
其三,如果在模型中,需要重復調用這個Swish_act()接口。會出現前一次使用的內存被提前釋放掉,使得在反向傳播計算中,需要使用的變量失效。
為了解決不能重復調用問題,可以使用nn.Module,創建CNN網絡模型的Module子類。
方法二:
class Act_op(nn.Module):
def __init__(self):
super(Act_op, self).__init__()
def forward(self, x):
x = x * F.sigmoid(x)
return x
簡單、快捷、方便。還不用自己寫backward。而且,不需要對之前CNN模型里,class Net(nn.Module)做任何改動。
---------------------
作者:cling-L
來源:CSDN
原文:https://blog.csdn.net/lingdexixixi/article/details/79796605
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!