GRU(Gated Recurrent Unit) 更新過程推導及簡單代碼實現
RNN網絡考慮到了具有時間數列的樣本數據,但是RNN仍存在着一些問題,比如隨着時間的推移,RNN單元就失去了對很久之前信息的保存和處理的能力,而且存在着gradient vanishing問題。
所以有些特殊類型的RNN網絡相繼被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)(Chao,et al. 2014).這里我主要推導一下GRU參數的迭代過程
GRU單元結構如下圖所示

1479126283494.jpg
數據流過程如下

其中表示Hadamard積,即對應元素乘積;下標表示節點的index,上標表示時刻;
表示隱層到輸出層的參數矩陣,
分別是隱層和輸出層的節點個數;
分別表示輸入和上一時刻隱層到更新門z的連接矩陣,
表示輸入數據的維度;
分別表示輸入和上一時刻隱層到重置門r的連接矩陣;
分別表示輸入和上一時刻的隱層到待選狀態
的連接矩陣。
針對於時刻t,使用鏈式求導法則,計算參數矩陣的梯度,其中E是代價函數,首先計算對隱層輸出的梯度,因為隱層輸出牽涉到多個時刻

所以

其中分別是對應激活函數的線性和部分
現在對參數計算梯度

令

則

將上面的式子矢量化(行向量)表示:


那接下來使用matlab來實現一個小例子,看看GRU的效果,同樣是二進制相加的問題
- function error= GRUtest( )
- % 初始化訓練數據
- uNum=16;%單元個數
- maxInt=2^uNum;
- % 初始化網絡結構
- xdim=2;
- ydim=1;
- hdim=16;
- eta=0.1;
- %初始化網絡參數
- Wy=rand(hdim,ydim)*2-1;
- Wr=rand(xdim,hdim)*2-1;
- Ur=rand(hdim,hdim)*2-1;
- W =rand(xdim,hdim)*2-1;
- U =rand(hdim,hdim)*2-1;
- Wz=rand(xdim,hdim)*2-1;
- Uz=rand(hdim,hdim)*2-1;
-
- rvalues=zeros(uNum+1,hdim);
- zvalues=zeros(uNum+1,hdim);
- hbarvalues=zeros(uNum,hdim);
- hvalues = zeros(uNum,hdim);
- yvalues=zeros(uNum,ydim);
-
- for p=1:10000
- aInt=randi(maxInt/2);
- bInt=randi(maxInt/2);
- cInt=aInt+bInt;
- at=dec2bin(aInt)-'0';
- bt=dec2bin(bInt)-'0';
- ct=dec2bin(cInt)-'0';
- a=zeros(1,uNum);
- b=zeros(1,uNum);
- c=zeros(1,uNum);
- a(1:size(at,2))=at(end:-1:1);
- b(1:size(bt,2))=bt(end:-1:1);
- c(1:size(ct,2))=ct(end:-1:1);
- xvalues=[a;b]';
- d=c';
-
- % 前向計算
- rvalues(1,:)=sigmoid(xvalues(1,:)*Wr);
- hbarvalues(1,:)=outTanh(xvalues(1,:)*W);
- zvalues(1,:)=sigmoid(xvalues(1,:)*Wz);
- hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:);
- yvalues(1,:)=sigmoid(hvalues(1,:)*Wy);
- for t=2:uNum
- rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur);
- hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U);
- zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz);
- hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:);
- yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);
- end
-
- % 誤差反向傳播
- delta_r_next=zeros(1,hdim);
- delta_z_next=zeros(1,hdim);
- delta_h_next=zeros(1,hdim);
- delta_next=zeros(1,hdim);
-
- dWy=zeros(hdim,ydim);
- dWr=zeros(xdim,hdim);
- dUr=zeros(hdim,hdim);
- dW=zeros(xdim,hdim);
- dU=zeros(hdim,hdim);
- dWz=zeros(xdim,hdim);
- dUz=zeros(hdim,hdim);
-
- for t=uNum:-1:2
- delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
- delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
- delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:));
- delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
- delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
-
- dWy=dWy+hvalues(t,:)'*delta_y;
- dWz=dWz+xvalues(t,:)'*delta_z;
- dUz=dUz+hvalues(t-1,:)'*delta_z;
- dW =dW+xvalues(t,:)'*delta;
- dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ;
- dWr=dWr+xvalues(t,:)'*delta_r;
- dUr=dUr+hvalues(t-1,:)'*delta_r;
-
- delta_r_next=delta_r;
- delta_z_next=delta_z;
- delta_h_next=delta_h;
- delta_next =delta;
-
- end
-
- t=1;
- delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
- delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
- delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:));
- delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
- delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
-
- dWy=dWy+hvalues(t,:)'*delta_y;
- dWz=dWz+xvalues(t,:)'*delta_z;
- dW =dW+xvalues(t,:)'*delta;
- dWr=dWr+xvalues(t,:)'*delta_r;
-
- Wy = Wy-eta*dWy;
- Wr = Wr-eta*dWr;
- Ur = Ur-eta*dUr;
- W = W -eta*dW;
- U = U-eta*dU;
- Wz = Wz-eta*dWz;
- Uz = Uz-eta*dUz;
- error = (norm(yvalues-d,2))/2.0;
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- if mod(p,500)==0
- fprintf('******************第%s次迭代****************\n',int2str(p));
- yvalues=round(yvalues(end:-1:1));
- y=bin2dec(int2str(yvalues'));
- fprintf('y=%d\n',y);
- fprintf('c=%d\n',cInt);
- fprintf('樣本誤差:e=%f\n',error);
- end
- end
- end
-
- function f=sigmoid(x)
- f=1./(1+exp(-x));
- end
-
- function fd = diffsigmoid(f)
- fd=f.*(1-f);
- end
-
- function g=outTanh(x)
- g=1-2./(1+exp(2*x));
- end
-
- function gd=diffoutTanh(g)
- gd=1-g.^2;
- end
部分實驗結果

1479392393541.jpg