Deep Residual Shrinkage Networks: A New Attention-Based Deep Learning Algorithm


This article explains a new attention-based deep learning algorithm: the Deep Residual Shrinkage Network (DRSN). Functionally, the DRSN is a feature-learning method designed for data that is heavily noisy or highly redundant. The article first reviews the relevant background, then introduces the motivation behind the DRSN and its concrete implementation. We hope it proves helpful.

1. Background

The DRSN builds mainly on three foundations: deep residual networks, the soft thresholding function, and the attention mechanism.

1.1 Deep Residual Networks

The deep residual network (ResNet) is without doubt one of the most successful deep learning algorithms of recent years, with more than 40,000 citations on Google Scholar. Compared with an ordinary convolutional neural network, a ResNet introduces cross-layer identity shortcuts, which ease the training of deep networks. The backbone of a ResNet is a stack of residual blocks; one common form of residual block is shown in the figure below.
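
To make the structure concrete, here is a minimal sketch of one such residual block in Keras, written in the same pre-activation style as the full listing later in this article; the assumption that the input already has `channels` feature channels (so the shortcut needs no projection) is ours.

from keras.layers import BatchNormalization, Activation, Conv2D, add

def residual_block(x, channels):
    # Main path: two BN -> ReLU -> Conv stages (pre-activation ordering)
    y = BatchNormalization()(x)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    # Identity shortcut: adding the input unchanged is what eases training
    return add([y, x])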

1.2 The Soft Thresholding Function

Soft thresholding is the core step of most signal denoising methods. First, a positive threshold must be set. The threshold must not be too large; in particular, it must not exceed the maximum absolute value of the input data, or the output would be all zeros. The soft thresholding function then sets inputs whose absolute value falls below the threshold to zero, and shrinks inputs whose absolute value exceeds the threshold toward zero. The input-output relationship is shown in figure (a) below.

The derivative of the output y with respect to the input x is shown in figure (b) above: it is either 0 or 1. In this respect, soft thresholding resembles the ReLU activation function, and it is likewise friendly to gradient back-propagation during training. Note, however, that the choice of threshold directly affects the result of soft thresholding, and how to select it remains a difficult problem to this day.
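
As a quick sketch, soft thresholding takes only a few lines of NumPy; the formula y = sign(x) * max(|x| - tau, 0) is the standard definition, and the sample values below are our own.

import numpy as np

def soft_threshold(x, tau):
    # Zero out |x| < tau and shrink the remaining values toward zero
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-3., -2., -1., 0., 1., 2., 3.])
print(soft_threshold(x, tau=1.0))  # elementwise: -2, -1, 0, 0, 0, 1, 2
# The derivative w.r.t. x is 0 inside (-tau, tau) and 1 outside,
# the same 0/1 gradient pattern as ReLU.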

1.3 The Attention Mechanism

The attention mechanism has been one of the hottest research topics in deep learning in recent years, and the Squeeze-and-Excitation Network (SENet) is among the most classic attention algorithms. As shown in the figure below, SENet uses a small embedded network to learn a set of weighting coefficients that rescale the individual feature channels. This is, in effect, an attention mechanism: first assess how important each feature channel is, then assign each channel a weight that matches its importance.

As shown in the figure below, SENet can also be integrated with a residual block. In this configuration, the cross-layer identity shortcut makes the SENet easier to train. It is also worth noting that the weighting coefficients are computed from each sample itself; in other words, every sample can have its own unique set of weights.
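
A minimal Keras sketch of the squeeze-and-excitation step may help; the two-layer bottleneck with a reduction ratio of 4 is an illustrative choice (the original SENet paper uses 16).

from keras.layers import GlobalAveragePooling2D, Dense, Reshape, multiply

def se_block(x, channels, reduction=4):
    # Squeeze: compress each feature channel into a single scalar
    w = GlobalAveragePooling2D()(x)
    # Excitation: a small two-layer network outputs one (0, 1) weight per channel
    w = Dense(channels // reduction, activation='relu')(w)
    w = Dense(channels, activation='sigmoid')(w)
    w = Reshape((1, 1, channels))(w)
    # Reweight: scale every channel by its learned importance
    return multiply([x, w])

To combine this with a residual block, the reweighting is applied to the main path just before the identity shortcut is added back.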

2. The Deep Residual Shrinkage Network

This section introduces, in turn, the motivation, implementation, and advantages of the DRSN.

2.1 Motivation

First, most real-world data, including images, speech, and vibration signals, contain at least some noise or redundant information. Broadly speaking, any information in a sample that is irrelevant to the current pattern recognition task can be regarded as noise or redundancy, and it is likely to harm performance on that task.

Second, any two samples usually differ in how much noise or redundancy they contain: some samples carry more, others less. This calls for algorithms that can set the relevant parameters individually for each sample, according to that sample's own characteristics.

Driven by these two observations, a natural question arises: can the soft thresholding function from classical signal denoising be brought into a deep residual network, and how should its thresholds be chosen? The DRSN offers one answer.

2.2 Implementation

The DRSN fuses the deep residual network, SENet, and the soft thresholding function. As shown in the figure below, it replaces the "reweighting" step in the residual variant of SENet with "soft thresholding". In SENet, the small embedded network produces a set of weighting coefficients; in the DRSN, the same kind of small network instead produces a set of thresholds.

To obtain suitable thresholds, the structure of this small network is adjusted relative to the original SENet. Specifically, the threshold it outputs is (the average absolute value of each feature channel) × (a coefficient between 0 and 1). This guarantees that every threshold is positive but not too large, so the output can never be all zeros.
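
In NumPy terms, the threshold rule can be sketched as follows; the channels-last shapes and the name `scales` (standing in for the sigmoid outputs of the small network) are our assumptions.

import numpy as np

def channel_thresholds(features, scales):
    # features: (batch, height, width, channels); scales: (batch, channels), each in (0, 1)
    abs_mean = np.abs(features).mean(axis=(1, 2))  # average |x| per channel
    tau = abs_mean * scales                        # positive, and below each channel's mean |x|
    return tau[:, None, None, :]                   # broadcastable over height and width

The block then applies y = sign(x) * max(|x| - tau, 0) element-wise, with a different tau for every sample and channel.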

As shown in the figure below, the overall architecture of the DRSN is the same as that of an ordinary deep residual network: an input layer, an initial convolutional layer, a stack of basic blocks, and finally a global average pooling layer and a fully connected output layer.

2.3 Advantages

First, the thresholds required by the soft thresholding function are set automatically by a small embedded network, so no expert knowledge is needed to set them by hand.

Second, the DRSN guarantees that the thresholds are positive and within a suitable range, which rules out the all-zero output case.

Third, each sample has its own unique set of thresholds, which makes the DRSN suitable for situations where the noise content varies from sample to sample.

3. Conclusion

Since noise and redundant information are ubiquitous, the DRSN, or more generally this combination of an attention mechanism with soft thresholding, may have broad room for extension and a wide range of applications.

Original paper:

M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, “Deep Residual Shrinkage Networks for Fault Diagnosis,” IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 

Keras code:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random(x_train.shape)
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random(x_test.shape)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs,-1)
    inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual Shrinkage Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 
                          padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding: y = sign(x) * max(|x| - thres, 0)
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])   # a zero tensor of matching shape
        n_sub = keras.layers.maximum([sub, zeros])  # max(|x| - thres, 0)
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling (it is important to use a pool size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
            
        # Zero-padding to match channel counts (it is important to use zero padding rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
        
        residual = keras.layers.add([residual, identity])
    
    return residual


# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

TFLearn code:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
print('Training accuracy:', training_acc)
print('Validation accuracy:', validation_acc)

 

