The deep residual network (ResNet) won the Best Paper Award at the 2016 IEEE Conference on Computer Vision and Pattern Recognition, and its citation count on Google Scholar has reached 38,295.
The deep residual shrinkage network is an improved variant of the deep residual network; it is, in essence, an integration of the deep residual network, the attention mechanism, and the soft thresholding function.
To some extent, the working principle of the deep residual shrinkage network can be understood as follows: the attention mechanism notices the unimportant features, and the soft thresholding function sets them to zero; or, put another way, the attention mechanism notices the important features and keeps them. This strengthens a deep neural network's ability to extract useful features from noisy signals.
1. Why propose the deep residual shrinkage network?
First, when classifying samples, noise is inevitably present, such as Gaussian noise, pink noise, or Laplacian noise. More broadly speaking, a sample is likely to contain information irrelevant to the current classification task, and this information can also be regarded as noise. Such noise may harm classification performance. (Soft thresholding is a key step in many signal denoising algorithms.)
For example, during a conversation by the roadside, the speech may be mixed with horn sounds, tire noise, and so on. When speech recognition is performed on these sound signals, the recognition result inevitably suffers from the horns and tires. From a deep learning perspective, the features corresponding to the horn and tire sounds should be deleted inside the deep neural network to keep them from degrading recognition.
Second, even within the same data set, the amount of noise often differs from sample to sample. (This is related to the attention mechanism: taking an image data set as an example, the target object may appear at different positions in different images; an attention mechanism can locate the target object in each individual image.)
For example, when training a cat-and-dog classifier on 5 images labeled "dog", the first image may also contain a mouse, the second a goose, the third a chicken, the fourth a donkey, and the fifth a duck. Training is then inevitably disturbed by these irrelevant mice, geese, chickens, donkeys, and ducks, lowering the classification accuracy. If we could notice these irrelevant animals and delete their corresponding features, the accuracy of the cat-and-dog classifier might improve.
2. Soft thresholding is the core step of many signal denoising algorithms
Soft thresholding is the core step of many signal denoising algorithms: features whose absolute value is smaller than a certain threshold are deleted (set to zero), while features whose absolute value is larger than the threshold are shrunk toward zero. It can be implemented with the following formula, where x is the input feature, y is the output feature, and \tau is the threshold:

y = \begin{cases} x - \tau, & x > \tau \\ 0, & -\tau \le x \le \tau \\ x + \tau, & x < -\tau \end{cases}

The derivative of the soft-thresholding output with respect to its input is

\frac{\partial y}{\partial x} = \begin{cases} 1, & x > \tau \\ 0, & -\tau \le x \le \tau \\ 1, & x < -\tau \end{cases}

As shown above, the derivative of soft thresholding is either 1 or 0, the same property as the ReLU activation function, so soft thresholding can also reduce the risk of vanishing and exploding gradients in deep learning algorithms.
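To make this concrete, here is a minimal NumPy sketch of soft thresholding and its 0/1 derivative (soft_threshold and soft_threshold_grad are illustrative names of ours, not functions from the paper):

import numpy as np

def soft_threshold(x, tau):
    # Delete features with |x| <= tau; shrink the rest toward zero by tau
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

def soft_threshold_grad(x, tau):
    # The derivative is 1 where |x| > tau and 0 elsewhere, like ReLU's 0/1 gradient
    return (np.abs(x) > tau).astype(float)

x = np.array([-2.0, -0.5, 0.0, 0.3, 1.5])
print(soft_threshold(x, 1.0))       # [-1. -0.  0.  0.  0.5]
print(soft_threshold_grad(x, 1.0))  # [1. 0. 0. 0. 1.]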
In the soft thresholding function, the threshold must satisfy two conditions: first, the threshold is positive; second, the threshold must not be larger than the maximum absolute value of the input signal, otherwise the output would be all zeros.
At the same time, the threshold should preferably satisfy a third condition: each sample should have its own threshold, determined by its own noise content.
This is because the noise content often differs across samples. For example, it frequently happens that, within the same data set, sample A contains little noise while sample B contains a lot. When soft thresholding is applied in a denoising algorithm, sample A should then use a smaller threshold and sample B a larger one. In a deep neural network, although the features and thresholds lose their explicit physical meaning, the basic principle still holds: each sample should have its own threshold, determined by its own noise content.
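Continuing the NumPy sketch above (reusing its soft_threshold function), here is a toy illustration of per-sample thresholds; the noise levels and threshold values below are made up for illustration:

import numpy as np
np.random.seed(0)
clean = np.sin(np.linspace(0, 2 * np.pi, 8))
sample_A = clean + 0.05 * np.random.randn(8)  # lightly noised sample
sample_B = clean + 0.50 * np.random.randn(8)  # heavily noised sample
# A small threshold suffices for sample A; sample B needs a larger one
# to suppress its stronger noise
denoised_A = soft_threshold(sample_A, tau=0.1)
denoised_B = soft_threshold(sample_B, tau=0.8)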
3. Attention mechanism
The attention mechanism is relatively easy to understand in the field of computer vision: an animal's visual system can quickly scan the entire field of view, find the target object, and then focus attention on that object to extract more detail while suppressing irrelevant information. For details, please refer to articles on attention mechanisms.
The Squeeze-and-Excitation Network (SENet) is a relatively recent deep learning method based on the attention mechanism. Across different samples, different feature channels often contribute differently to the classification task. SENet uses a small sub-network to obtain a set of weights and multiplies these weights with the features of the corresponding channels, rescaling each channel. This process can be viewed as applying attention of different strengths to the individual feature channels.
Under this scheme, every sample has its own set of weights; in other words, any two samples have different weights. In SENet, the weights are obtained along the path "global pooling → fully connected layer → ReLU → fully connected layer → Sigmoid".
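As a rough sketch (not SENet's official implementation), this weighting path can be written with the same Keras API used by the program in Section 6; the channel count and reduction ratio below are illustrative:

from keras.layers import Dense, GlobalAveragePooling2D, Reshape, multiply

def se_block(feature_map, channels, reduction=4):
    # "global pooling -> fully connected -> ReLU -> fully connected -> Sigmoid"
    w = GlobalAveragePooling2D()(feature_map)               # one value per channel
    w = Dense(channels // reduction, activation='relu')(w)  # bottleneck FC layer
    w = Dense(channels, activation='sigmoid')(w)            # weights in (0, 1)
    w = Reshape((1, 1, channels))(w)
    # Each sample gets its own set of weights; channels are rescaled accordingly
    return multiply([feature_map, w])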
4. Soft thresholding under a deep attention mechanism
The deep residual shrinkage network borrows the SENet sub-network structure described above to implement soft thresholding under a deep attention mechanism: a small sub-network learns a set of thresholds, which are then used to soft-threshold each feature channel.
In this sub-network, the absolute values of all features of the input feature map are computed first; through global average pooling (i.e., averaging), this yields a feature, denoted A. On a second path, the globally average-pooled feature map is fed into a small fully connected network whose last layer is a Sigmoid function, normalizing the output to between 0 and 1 and yielding a coefficient, denoted α. The final threshold is α × A, that is, a number between 0 and 1 multiplied by the average of the absolute values of the feature map. This not only guarantees that the threshold is positive, but also keeps it from being too large.
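A simplified sketch of this threshold path is given below. The author's full version is the residual_shrinkage_block function in Section 6 (which adds batch normalization and weight regularization); the shrinkage helper here is our illustrative reduction of it:

from keras import backend as K
from keras.layers import Dense, GlobalAveragePooling2D, Lambda, Reshape, multiply

def shrinkage(feature_map, channels):
    # A: per-channel average of the absolute values (always positive)
    abs_map = Lambda(lambda t: K.abs(t))(feature_map)
    A = GlobalAveragePooling2D()(abs_map)
    # alpha: coefficient in (0, 1) from a small FC path ending in a sigmoid
    alpha = Dense(channels, activation='relu')(A)
    alpha = Dense(channels, activation='sigmoid')(alpha)
    # Threshold = alpha * A: positive, and no larger than the mean absolute feature
    tau = Reshape((1, 1, channels))(multiply([A, alpha]))
    # Soft thresholding: sign(x) * max(|x| - tau, 0)
    return Lambda(lambda t: K.sign(t[0]) * K.maximum(K.abs(t[0]) - t[1], 0.0))(
        [feature_map, tau])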
Moreover, each sample gets its own threshold. To some extent, this can therefore be understood as a special attention mechanism: notice the features irrelevant to the current task and, through soft thresholding, set them to zero; or, put another way, notice the features relevant to the current task and keep them.
Finally, stacking a number of these basic modules together with convolutional layers, batch normalization, activation functions, global average pooling, and a fully connected output layer yields the complete deep residual shrinkage network.
5. The deep residual shrinkage network may be more broadly applicable
The deep residual shrinkage network is in fact a general-purpose feature learning method. In many feature learning tasks, the samples contain more or less noise as well as irrelevant information, and this noise and irrelevant information may harm feature learning. For example:
In image classification, if an image also contains many other objects, these objects can be regarded as "noise"; the deep residual shrinkage network might use the attention mechanism to notice this "noise" and use soft thresholding to set the corresponding features to zero, potentially improving image classification accuracy.
In speech recognition, in relatively noisy environments, such as a conversation by the roadside or on a factory floor, the deep residual shrinkage network might improve speech recognition accuracy, or it at least offers an approach for doing so.
6. Brief introduction to the Keras and TFLearn programs
Taking image classification as an example, the programs below build a small deep residual shrinkage network; the hyperparameters have not been tuned. For higher accuracy, one can increase the depth, train for more iterations, and adjust the hyperparameters. The Keras program comes first:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1

M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault
Diagnosis, IEEE Transactions on Industrial Informatics, 2019,
DOI: 10.1109/TII.2019.2943898

There might be some problems in the Keras code.
The weights in custom layers of models created using the Keras functional API
may not be optimized.
https://www.reddit.com/r/MachineLearning/comments/hrawam/d_theres_a_flawbug_in_tensorflow_thats_preventing/
TensorFlow has been reported to have a serious bug: models built with Keras
may fail to train some of their weights, and the bug remains unfixed.
https://cloud.tencent.com/developer/news/661458
The TFLearn code is recommended for usage.
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks/blob/master/DRSN_TFLearn.py

@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Noised data (the noise shape follows the data format chosen above)
x_train = x_train.astype('float32') / 255. + 0.5 * np.random.random(x_train.shape)
x_test = x_test.astype('float32') / 255. + 0.5 * np.random.random(x_test.shape)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)


def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs, 1), 1)


def sign_backend(inputs):
    return K.sign(inputs)


def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels) // 2
    inputs = K.expand_dims(inputs, -1)
    inputs = K.spatial_3d_padding(inputs, ((0, 0), (0, 0), (pad_dim, pad_dim)),
                                  'channels_last')
    return K.squeeze(inputs, -1)


# Residual Shrinkage Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels,
                             downsample=False, downsample_strides=2):

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    for i in range(nb_blocks):

        identity = residual

        if not downsample:
            downsample_strides = 1

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3,
                          strides=(downsample_strides, downsample_strides),
                          padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same',
                          kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)

        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)

        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None,
                       kernel_initializer='he_normal',
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid',
                       kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)

        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])

        # Soft thresholding: sign(x) * max(|x| - threshold, 0)
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])  # an all-zeros tensor
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])

        # Downsampling (it is important to use a pool size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1, 1), strides=(2, 2))(identity)

        # Zero-padding to match channels (it is important to use zero padding
        # rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend,
                              arguments={'in_channels': in_channels,
                                         'out_channels': out_channels})(identity)

        residual = keras.layers.add([residual, identity])

    return residual


# Define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal',
             kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal',
                kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(),
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1,
          validation_data=(x_test, y_test))

# Get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])
The TFLearn program follows:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks
for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019,
DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3)) * 0.1
testX = testX + np.random.random((10000, 32, 32, 3)) * 0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)


def residual_shrinkage_block(incoming, nb_blocks, out_channels,
                             downsample=False, downsample_strides=2,
                             activation='relu', batch_norm=True, bias=True,
                             weights_init='variance_scaling', bias_init='zeros',
                             regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False,
                             scope=None, name="ResidualBlock"):

    # Residual shrinkage blocks with channel-wise thresholds
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name,
                                   values=[incoming], reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, downsample_strides,
                               'same', 'linear', bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable, restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear',
                               bias, weights_init, bias_init, regularizer,
                               weight_decay, trainable, restore)

            # Get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2,
                                                     keep_dims=True),
                                      axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels // 4,
                                             activation='linear',
                                             regularizer='L2',
                                             weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels,
                                             activation='linear',
                                             regularizer='L2',
                                             weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))

            # Soft thresholding
            residual = tf.multiply(tf.sign(residual),
                                   tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1, downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch + 1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True,
          run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
Paper:
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, vol. 16, no. 7, pp. 4681-4690, 2020.
https://ieeexplore.ieee.org/document/8850096
GitHub page:
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks