As its name suggests, the deep residual shrinkage network (DRSN) is composed of two parts, a "residual network" and "shrinkage", and is an improved variant of the residual network.
The residual network (ResNet) won the ImageNet image-recognition competition in 2015 and has since become a foundational architecture in deep learning; "shrinkage" refers to soft thresholding, the core step of many signal-denoising methods.
The deep residual shrinkage network is also a deep learning method built on an attention mechanism: the thresholds required for its soft thresholding are, in essence, set under an attention mechanism.
In this article, we first briefly review the fundamentals of residual networks, soft thresholding, and attention mechanisms, and then explain the motivation, algorithm, and applications of the deep residual shrinkage network.
1. Background
1.1 Residual Networks
A residual network (also called a deep residual network or deep residual learning, ResNet for short) is a type of convolutional neural network. Compared with an ordinary convolutional neural network, a residual network uses cross-layer identity connections (shortcuts) to ease training. A basic building block of a residual network is shown in Figure 1.
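To make the cross-layer identity connection concrete, here is a minimal Keras sketch of a pre-activation residual block; the layer sizes and ordering are illustrative assumptions and not necessarily the exact module shown in Figure 1.

# Minimal pre-activation residual block (illustrative sketch)
from keras.layers import Conv2D, BatchNormalization, Activation, add

def basic_residual_block(x, channels):
    shortcut = x                                # cross-layer identity connection
    y = BatchNormalization()(x)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    # output = F(x) + x; assumes x already has `channels` channels
    return add([y, shortcut])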
1.2 Soft Thresholding
Soft thresholding is the core step of many signal-denoising methods. It sets features whose absolute values are below a certain threshold to zero and shrinks the remaining features towards zero as well, which is what "shrinkage" means. The threshold is a parameter that must be set in advance, and its value has a direct impact on the denoising result. The relationship between the input and the output of soft thresholding is shown in Figure 2.
As Figure 2 shows, soft thresholding is a nonlinear transformation with a property very similar to that of the ReLU activation function: its gradient is either 0 or 1. Soft thresholding can therefore also serve as an activation function in a neural network, and some neural networks have in fact already used it in that role.
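As a minimal NumPy sketch (our own illustration, not code from the paper), soft thresholding with threshold tau can be written as follows; the example values show how small features are zeroed and large ones are shrunk.

import numpy as np

def soft_threshold(x, tau):
    # y = sign(x) * max(|x| - tau, 0): values with |x| <= tau become zero,
    # the remaining values are shrunk towards zero by tau
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

print(soft_threshold(np.array([-2.5, -0.5, 0.3, 1.8]), 1.0))  # [-1.5  0.   0.   0.8]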
1.3 Attention Mechanisms
An attention mechanism is a mechanism for focusing attention on locally important information. It can be divided into two steps: first, scan the whole input to locate the locally useful information; second, enhance the useful information and suppress the redundant information.
The Squeeze-and-Excitation Network (SENet) is a classic deep learning method built on an attention mechanism. It uses a small sub-network to automatically learn a set of weights that reweight the channels of a feature map. The underlying idea is that some feature channels are important while others are largely redundant, so the useful channels can be strengthened and the redundant ones weakened. A basic module of the Squeeze-and-Excitation Network is shown in the figure below.
It is worth pointing out that in this way each sample obtains its own set of weights, so the channel weighting is adapted to the characteristics of that particular sample. For example, suppose the first feature channel of sample A is important and its second channel is not, while for sample B the opposite holds; then sample A gets a set of weights that strengthens its first channel and weakens its second, and sample B gets a set of weights that weakens its first channel and strengthens its second.
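The following minimal Keras sketch illustrates this SE-style channel weighting; the reduction ratio r = 4 and the specific layer choices are assumptions for illustration only.

from keras.layers import GlobalAveragePooling2D, Dense, Reshape, multiply

def se_weighting(x, channels, r=4):
    w = GlobalAveragePooling2D()(x)                 # squeeze: one summary value per channel
    w = Dense(channels // r, activation='relu')(w)  # excitation: a small sub-network ...
    w = Dense(channels, activation='sigmoid')(w)    # ... outputs one weight in (0, 1) per channel
    w = Reshape((1, 1, channels))(w)
    return multiply([x, w])                         # reweight the channels, per sample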
2. Theory of the Deep Residual Shrinkage Network
2.1 Motivation
First, real-world data almost always contains some redundant information. We can therefore try to embed soft thresholding into the residual network to eliminate this redundancy.
Second, the amount of redundant information often differs from sample to sample. We can therefore use an attention mechanism to adaptively set a different threshold for each sample according to its own characteristics.
2.2 Algorithm
Like the residual network and the Squeeze-and-Excitation Network, the deep residual shrinkage network is built by stacking many basic modules. Each basic module contains a sub-network that automatically learns a set of thresholds, which are then used to soft-threshold the feature map. In this way, each sample obtains its own set of thresholds. A basic module of the deep residual shrinkage network is shown in the figure below.
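The sketch below condenses how such a threshold sub-network can be wired in Keras; it mirrors the fuller Keras implementation later in this article, and the function and tensor names are illustrative.

from keras import backend as K
from keras.layers import GlobalAveragePooling2D, Dense, Lambda, multiply

def channel_thresholds(feature_map, channels):
    # Per-sample, per-channel average of the absolute feature values
    abs_map = Lambda(lambda t: K.abs(t))(feature_map)
    abs_mean = GlobalAveragePooling2D()(abs_map)
    # A small sub-network outputs a scaling factor in (0, 1) per channel
    alpha = Dense(channels, activation='relu')(abs_mean)
    alpha = Dense(channels, activation='sigmoid')(alpha)
    # Threshold = scaling factor * average absolute value: always positive,
    # never larger than the average magnitude, and different for every sample
    thres = multiply([abs_mean, alpha])
    # Expand to shape (batch, 1, 1, channels) so it broadcasts over the feature map
    return Lambda(lambda t: K.expand_dims(K.expand_dims(t, 1), 1))(thres)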
The overall architecture of the deep residual shrinkage network, shown in the figure below, consists of an input layer, a stack of basic modules, and a final fully connected output layer.
2.3 Applications
In the original paper, the deep residual shrinkage network was applied to the fault diagnosis of rotating machinery based on vibration signals. In principle, however, it targets any dataset that contains redundant information, and redundant information is everywhere. For example, in image recognition an image always contains regions unrelated to the label, and in speech recognition audio often contains various forms of noise. The deep residual shrinkage network, or more generally the idea of combining an attention mechanism with soft thresholding, therefore has fairly broad research value and application prospects.
Keras Code
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1

M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs, 1), 1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels) // 2
    inputs = K.expand_dims(inputs, -1)
    inputs = K.spatial_3d_padding(inputs, ((0, 0), (0, 0), (pad_dim, pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual Shrinkage Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    for i in range(nb_blocks):

        identity = residual

        if not downsample:
            downsample_strides = 1

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
                          padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)

        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)

        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])

        # Soft thresholding
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])

        # Downsampling (it is important to use a pool size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1, 1), strides=(2, 2))(identity)

        # Zero-padding to match channels (it is important to use zero padding rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels': in_channels,
                                                      'out_channels': out_channels})(identity)

        residual = keras.layers.add([residual, identity])

    return residual

# Define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal',
             kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal',
                kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1,
          validation_data=(x_test, y_test))

# Get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])
TFLearn Code
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)

def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2, activation='relu', batch_norm=True,
                             bias=True, weights_init='variance_scaling',
                             bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False, scope=None,
                             name="ResidualBlock"):

    # Residual shrinkage blocks with channel-wise thresholds
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, downsample_strides, 'same',
                               'linear', bias, weights_init, bias_init, regularizer,
                               weight_decay, trainable, restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear', bias,
                               weights_init, bias_init, regularizer, weight_decay,
                               trainable, restore)

            # Get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                      axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))

            # Soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1, downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels

            residual = residual + identity

    return residual

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
Reference
M. Zhao, S. Zhong, X. Fu, et al. Deep residual shrinkage networks for fault diagnosis. IEEE Transactions on Industrial Informatics, 2019. DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096/
Source Code
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks