深度殘差收縮網絡 Deep Residual Shrinkage Networks for Fault Diagnosis (原文翻譯)


深度殘差收縮網絡是深度殘差網絡的一種改進,針對的是數據中含有噪聲或冗余信息的情況,將軟閾值函數引入深度殘差網絡的內部,通過消除冗余特征,增強高層特征的判別性。其核心部分就是下圖所示的基本模塊:

以下對部分原文進行了翻譯,僅以學習為目的。

【題目】Deep Residual Shrinkage Networks for Fault Diagnosis

【翻譯】基於深度殘差收縮網絡的故障診斷

Abstract (摘要)

Abstract: This paper develops new deep learning methods, namely, deep residual shrinkage networks, to improve the feature learning ability from highly noised vibration signals and achieve a high fault diagnosing accuracy. Soft thresholding is inserted as nonlinear transformation layers into the deep architectures to eliminate unimportant features. Moreover, considering that it is generally challenging to set proper values for the thresholds, the developed deep residual shrinkage networks integrate a few specialized neural networks as trainable modules to automatically determine the thresholds, so that professional expertise on signal processing is not required. The efficacy of the developed methods is validated through experiments with various types of noise.

摘要:本文提出了一種新的深度學習方法,名為深度殘差收縮網絡,來提高深度學習方法從強噪聲信號中學習特征的能力,並且取得較高的故障診斷准確率。軟閾值化作為非線性層,嵌入到深度神經網絡之中,以消除不重要的特征。更進一步地,考慮到軟閾值化中的閾值是難以設定的,本文所提出的深度殘差收縮網絡,采用了一個子網絡,來自動地設置這些閾值,從而回避了信號處理領域的專業知識。該方法的有效性通過多種不同噪聲下的實驗進行了驗證。

【關鍵詞】Deep learning, deep residual networks, fault diagnosis, soft thresholding, vibration signal.

【翻譯】深度學習,深度殘差網絡,故障診斷,軟閾值化,振動信號。

I. Introduction (引言) 

【翻譯】旋轉機械在制造業、電力供應、運輸業和航天工業都是很重要的。然而,因為這些旋轉機械工作在嚴酷的工作環境下,其機械傳動系統不可避免地會遭遇一些故障,並且會導致事故和經濟損失。准確的機械傳動系統故障診斷,能夠用來安排維修計划、延長服役壽命和確保人身安全。

【翻譯】現有的機械傳動系統故障診斷算法可分為兩類,一類是基於信號分析的方法,另一類是基於機器學習的方法。通常,基於信號分析的故障診斷方法通過檢測故障相關的振動成分或者特征頻率,來確定故障類型。然而,對於大型旋轉機械,其振動信號往往是由許多不同的振動信號混疊而成的,包括齒輪的嚙合頻率、軸和軸承的旋轉頻率等。更重要地,當故障處於早期階段的時候,故障相關的振動成分往往是比較微弱的,容易被其他的振動成分和諧波所淹沒。總而言之,傳統基於信號分析的故障診斷方法經常難以檢測到故障相關的振動成分和特征頻率。

【翻譯】從另一方面來講,基於機器學習的故障診斷方法,在診斷故障的時候不需要確定故障相關的成分和特征頻率。首先,一組統計特征(例如峭度、均方根值、能量、熵)能夠被提取來表征健康狀態;然后一個分類器(例如多分類支持向量機、單隱含層的神經網絡、朴素貝葉斯分類器)能夠被訓練來診斷故障。然而,所提取的統計特征經常是判別性不足的,難以區分故障,從而導致了低的診斷准確率。因此,尋找一個判別性強的特征集,是基於機器學習的故障診斷中一個長期的挑戰。

【翻譯】近年來,深度學習方法,即有多個非線性映射層的機器學習方法,成為了基於振動信號進行故障診斷的有力工具。深度學習方法能夠自動地從原始振動數據中學習特征,以取代傳統的統計特征,來獲得高的診斷准確率。例如,Ince等人采用一維卷積神經網絡,從電流信號中學習特征,應用於實時電機故障診斷。Shao等人采用一種卷積深度置信網絡,應用於電機軸承的故障診斷。但是,一個問題是,誤差函數的梯度,在逐層反向傳播的過程中,逐漸變得不准確。因此,在輸入層附近的一些層的參數不能夠被很好地優化。

【翻譯】深度殘差網絡是卷積神經網絡的一個新穎的變種,采用了恆等路徑來減輕參數優化的難度。在深度殘差網絡中,梯度不僅逐層地反向傳播,而且通過恆等路徑直接傳遞到之前的層。由於優越的參數優化能力,深度殘差網絡在最近的一些研究中,已經被應用於故障診斷。例如,Ma等人將一種集成了解調時頻特征的深度殘差網絡,應用於不穩定工況下的行星齒輪箱故障診斷。Zhao等人使用深度殘差網絡,來融合多組小波包系數,應用於故障診斷。相較於普通的卷積神經網絡,深度殘差網絡的優勢已經在這些論文中得到了驗證。

【翻譯】從大型旋轉機械(例如風電、機床、重型卡車)所采集的振動信號,經常包含着大量的噪聲。在處理強噪聲振動信號的時候,深度殘差網絡的特征學習能力經常會降低。深度殘差網絡中的卷積核,其實就是濾波器,在噪聲的干擾下,可能不能檢測到故障特征。在這種情況下,在輸出層所學習到的高層特征,就會判別性不足,不能夠准確地進行故障分類。因此,開發新的深度學習方法,應用於強噪聲下旋轉機械的故障診斷,是十分必要的。

【翻譯】本文提出了兩種深度殘差收縮網絡,即通道間共享閾值的深度殘差收縮網絡、通道間不同閾值的深度殘差收縮網絡,來提高從強噪聲振動信號中學習特征的能力,最終提高故障診斷准確率。本文的主要貢獻總結如下:

(1) 軟閾值化(也就是一種流行的收縮方程)作為非線性層,被嵌入深度結構之中,以有效地消除噪聲相關的特征。

(2) 采用特殊設計的子網絡,來自適應地設置閾值,從而每段振動信號都有着自己獨特的一組閾值。

(3) 在軟閾值化中,共考慮了兩種閾值,也就是通道間共享的閾值、通道間不同的閾值。這也是所提出方法名稱的由來。

【翻譯】本文的剩余部分安排如下。第二部分簡要地回顧了經典的深度殘差網絡,並且詳細闡述了所提出的深度殘差收縮網絡。第三部分進行了實驗對比。第四部分進行了總結。

II. Theory of the developed DRSNs (深度殘差收縮網絡的理論)

 

【翻譯】如第一部分所述,作為一種潛在的、能夠從強噪聲振動信號中學習判別性特征的方法,本研究考慮了深度學習和軟閾值化的集成。相對應地,本部分注重於開發深度殘差網絡的兩個改進的變種,即通道間共享閾值的深度殘差收縮網絡、通道間不同閾值的深度殘差收縮網絡。對相關理論背景和必要的想法進行了詳細介紹。

A. Basic Components (基本組成)

【翻譯】不管是深度殘差網絡,還是所提出的深度殘差收縮網絡,都有一些基礎的組成,是和傳統卷積神經網絡相同的,包括卷積層、整流線性單元激活函數、批標准化、全局均值池化、交叉熵誤差函數。這些基礎組成的概念在下面進行了介紹。

【翻譯】卷積層是使得卷積神經網絡不同於傳統全連接神經網絡的關鍵。卷積層能夠大量減少所需要訓練的參數的數量。這是通過用卷積,取代乘法矩陣,來實現的。卷積核中的參數,比全連接層中的權重,少得多。更進一步地,當參數較少時,深度學習不容易遭遇過擬合,從而能夠在測試集上獲得較高的准確率。輸入特征圖和卷積核之間的卷積運算,附帶着加上偏置,能夠用公式表示為…。卷積可以通過重復一定次數,來獲得輸出特征圖。

【翻譯】圖1展示了卷積的過程。如圖1(a)-(b)所示,特征圖和卷積核實際上是三維張量。在本文中,一維振動信號是輸入,所以特征圖和卷積核的高度始終是1。如圖1(c)所示,卷積核在輸入特征圖上滑動,從而得到輸出特征圖的一個通道。在每個卷積層中,通常有多於一個卷積核,從而輸出特征圖有多個通道。

【翻譯】圖1 (a) 特征圖,(b) 卷積核和(c)卷積過程示意圖

 

【翻譯】批標准化是一種嵌入到深度結構的內部、作為可訓練層的一種特征標准化方法。批標准化的目的在於減輕內部協方差漂移的問題,即特征的分布經常在訓練過程中持續變化。在這種情況下,所需訓練的參數就要不斷地適應變化的特征分布,從而增大了訓練的難度。批標准化,在第一步對特征進行標准化,來獲得一個固定的分布,然后在訓練過程中自適應地調整這個分布。后續介紹公式。

【翻譯】激活函數通常是神經網絡中必不可少的一部分,一般是用來實現非線性變換的。在過去的幾十年中,很多種激活函數被提出來,例如sigmoid,tanh和ReLU。其中,ReLU激活函數最近得到了很多關注,這是因為ReLU能夠很有效地避免梯度消失的問題。ReLU激活函數的導數要么是1,要么是0,能夠幫助控制特征的取值范圍大致不變,在特征在層間傳遞的時候。ReLU的函數表達式為max(x,0)。

【翻譯】 全局均值池化是從特征圖的每個通道計算一個平均值的運算。通常,全局均值池化是在最終輸出層之前使用的。全局均值池化可以減少全連接輸出層的權重數量,從而降低深度神經網絡遭遇過擬合的風險。全局均值池化還可以解決平移變化問題,從而深度神經網絡所學習得到的特征,不會受到故障沖擊位置變化的影響。 

【翻譯】交叉熵損失函數通常作為多分類問題的目標函數,朝着最小的方向進行優化。相較於傳統的均方差損失函數,交叉熵損失函數經常能夠提供更快的訓練速度。這是因為,交叉熵損失函數對於權重的梯度,相較於均方差損失函數,不容易減弱到零。為了計算交叉熵損失函數,首先要用softmax函數將特征轉換到零一區間。然后交叉熵損失函數可以根據公式進行計算。在獲得交叉熵損失函數之后,梯度下降法可以用來優化參數。在一定的迭代次數之后,深度神經網絡就能夠得到充分的訓練。

B. Architecture of the Classical ResNet (經典深度殘差網絡的結構)

【翻譯】深度殘差網絡是一種新興的深度學習方法,在近年來受到了廣泛的關注。殘差構建模塊是基本的組成部分。如圖2a所示,殘差構建模塊包含了兩個批標准化、兩個整流線性單元、兩個卷積層和一個恆等路徑。恆等路徑是讓深度殘差網絡優於卷積神經網絡的關鍵。交叉熵損失函數的梯度,在普通的卷積神經網絡中,是逐層反向傳播的。當使用恆等路徑的時候,梯度能夠更有效地流回前面的層,從而參數能夠得到更有效的更新。

圖2b-2c展示了兩種殘差構建模塊,能夠輸出不同尺寸的特征圖。在這里,減小輸出特征圖尺寸的原因在於,減小后續層的運算量;增加通道數的原因在於,方便將不同的特征集成為強判別性的特征。

圖2d展示了深度殘差網絡的整體框架,包括一個輸入層、一個卷積層、一定數量的殘差構建模塊、一個批標准化、一個ReLU激活函數、一個全局均值池化和一個全連接輸出層。同時,深度殘差網絡作為本研究的基准,以求進一步改進。

【翻譯】圖2 3種殘差構建模塊:(a) 輸入特征圖的尺寸=輸出特征圖的尺寸,(b)輸出特征圖的寬度減半,(c)輸出特征圖的寬度減半、通道數翻倍。(d)深度殘差網絡的整體框架。

C. Design of Fundamental Architectures for DRSNs (深度殘差收縮網絡基本結構的設計)

【翻譯】這一小節首先介紹了提出深度殘差收縮網絡的原始驅動,然后詳細介紹了所提出深度殘差收縮網絡的結構。

1) Theoretical background (理論背景)

【翻譯】在過去的20年中,軟閾值化經常被作為許多信號降噪算法中的關鍵步驟。通常,信號被轉換到一個域。在這個域中,接近零的特征,是不重要的。然后,軟閾值化將這些接近於零的特征置為零。例如,作為一種經典的信號降噪算法,小波閾值化通常包括三個步驟:小波分解、軟閾值化和小波重構。為了保證信號降噪的效果,小波閾值化的一個關鍵任務是設計一個濾波器。這個濾波器能夠將有用的信息轉換成比較大的特征,將噪聲相關的信息轉換成接近於零的特征。然而,設計這樣的濾波器需要大量的信號處理方面的專業知識,經常是非常困難的。深度學習提供了一種解決這個問題的新思路。這些濾波器可以通過反向傳播算法自動優化得到,而不是由專家進行設計。因此,軟閾值化和深度學習的結合是一種有效地消除噪聲信息和構建高判別性特征的方式。軟閾值化將接近於零的特征直接置為零,而不是像ReLU那樣,將負的特征置為零,所以負的、有用的特征能夠被保留下來。

【翻譯】軟閾值化的過程如圖3(a)所示。可以看出,軟閾值化的輸出對於輸入的導數要么是1,要么是0,所以在避免梯度消失和梯度爆炸的問題上,也是很有效的。

【翻譯】圖3 (a)軟閾值化,(b)它的偏導

【翻譯】在傳統的信號降噪算法中,經常難以給閾值設置一個合適的值。同時,對於不同的樣本,最優的閾值往往是不同的。針對這個問題,深度殘差收縮網絡的閾值,是在深度網絡中自動確定的,從而避免了人工的操作。深度殘差收縮網絡中,這種設置閾值的方式,在后續文中進行了介紹。

2) Architecture of the Developed DRSN-CS (通道間共享閾值的深度殘差收縮網絡結構)

所提出的通道間共享閾值的深度殘差收縮網絡,是深度殘差網絡的一個變種,使用了軟閾值化來消除與噪聲相關的特征。軟閾值化作為非線性層嵌入到殘差構建模塊之中。更重要地,閾值是在殘差構建模塊中自動學習得到的,介紹如下。

【翻譯】圖4 (a)通道間共享閾值的殘差模塊,(b)通道間共享閾值的深度殘差收縮網絡,(c)通道間不同閾值的殘差模塊,(d) 通道間不同閾值的深度殘差收縮網絡

【翻譯】如圖4(a)所示,名為“通道間共享閾值的殘差收縮構建模塊”,與圖2(a)中殘差構建模塊是不同的,有一個特殊模塊來估計軟閾值化所需要的閾值。在這個特殊模塊中,全局均值池化被應用在特征圖的絕對值上面,來獲得一維向量。然后,這個一維向量被輸入到一個兩層的全連接網絡中,來獲得一個尺度化參數。Sigmoid函數將這個尺度化參數規整到零和一之間。然后,這個尺度化參數,乘以特征圖的絕對值得平均值,作為閾值。這樣的話,就可以把閾值控制在一個合適的范圍內,不會使輸出特征全部為零。

【翻譯】所提出的通道間共享閾值的深度殘差收縮網絡的結構簡圖如圖4(b)所示,和圖2(d)中經典深度殘差網絡是相似的。唯一的區別在於,通道間共享閾值的殘差收縮模塊(RSBU-CS),替換了普通的殘差構建模塊。一定數量的RSBU-CS被堆疊起來,從而噪聲相關的特征被逐漸削減。另一個優勢在於,閾值是自動學習得到的,而不是由專家手工設置的,所以在實施通道間共享閾值的深度殘差收縮網絡的時候,不需要信號處理領域的專業知識。

3) Architecture of the developed DRSN-CW (通道間不同閾值的深度殘差收縮網絡結構)

【翻譯】道間不同閾值的深度殘差收縮網絡,是深度殘差網絡的另一個變種。與通道間共享閾值的深度殘差收縮網絡的區別在於,特征圖的每個通道有着自己獨立的閾值。通道間不同閾值的殘差模塊如圖4(c)所示。特征圖x首先被壓縮成了一個一維向量,並且輸入到一個兩層的全連接層中。全連接層的第二層有多於一個神經元,並且神經元的個數等於輸入特征圖的通道數。全連接層的輸出被強制到零和一之間。之后計算出閾值。與通道間共享閾值的深度殘差收縮網絡相似,閾值始終是正數,並且被保持在一個合理范圍內,從而防止輸出特征都是零的情況。

【翻譯】通道間不同閾值的深度殘差收縮網絡的整體框架如圖4(d)所示。一定數量的模塊被堆積起來,從而判別性特征能夠被學習得到。其中,軟閾值化,作為收縮函數,用於非線性變換,來消除噪聲相關的信息。

Reference:

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096

轉載:https://blog.csdn.net/ohangzi/article/details/103617702

參考鏈接:

秒懂深度殘差收縮網絡 https://www.jianshu.com/p/90f1ef1b06bc

深度殘差收縮網絡:(一)背景知識 https://www.cnblogs.com/yc-9527/p/11598844.html

深度殘差收縮網絡:(二)整體思路 https://www.cnblogs.com/yc-9527/p/11601322.html

深度殘差收縮網絡:(三)網絡結構 https://www.cnblogs.com/yc-9527/p/11603320.html

深度殘差收縮網絡:(四)注意力機制下的閾值設置 https://www.cnblogs.com/yc-9527/p/11604082.html

深度殘差收縮網絡:(五)實驗驗證 https://www.cnblogs.com/yc-9527/p/11610073.html

[論文筆記] 深度殘差收縮網絡 https://zhuanlan.zhihu.com/p/85238942

 

Keras示例代碼

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs,-1)
    inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 
                          padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling (it is important to use the pooL-size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
            
        # Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
        
        residual = keras.layers.add([residual, identity])
    
    return residual


# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

 

TFLearn示例代碼

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM