【600】Attention U-Net Explained


Reference: Attention-UNet for Pneumothorax Segmentation

Reference: Attention U-Net


1. Model Structure Diagram

  Note: the diagram is drawn for 3D data, where F stands for feature (channel), H for height, W for width, and D for depth, i.e. the depth of the 3D data volume. For ordinary 2D images D can be dropped; in addition, the channel dimension is usually placed last, so a feature map can be written as $H_1 \times W_1 \times F_1$.
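  To make the two layouts concrete, here is a minimal sketch (illustrative shapes only, not part of the referenced code) showing how the same 160×160 RGB image is declared in Keras under channels-first and channels-last ordering:

from keras import Input

# channels_first: (F, H, W) -- the layout used by the first implementation below
x_cf = Input((3, 160, 160))
print(x_cf.shape)   # (None, 3, 160, 160)

# channels_last: (H, W, F) -- the layout used by the channels-last variant at the end
x_cl = Input((160, 160, 3))
print(x_cl.shape)   # (None, 160, 160, 3)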

2. Illustration of the AttnBlock2D Function

  The figure below shows what the AttnBlock2D function implements; its output corresponds to the skip-connection layer in U-Net and needs to be followed by a Concatenation.

  The above is the original structure diagram of the Attention Gate; it can be understood with the help of the structure diagram below:

  • The inputs are $x$ (conv2d_126 at the top, which splits into two branches) and $g$ (up_sampling_2d_11 on the left)

  • $x$ passes through one convolution and $g$ passes through one convolution, and the two results are added together

    • the convolution applied to $x$ outputs in_channel // 4 channels, where in_channel is the channel count of the decoder feature map that is upsampled to form $g$ (512 // 4 = 128 in the deepest gate of the summary below)
    • the convolution applied to $g$ likewise outputs in_channel // 4 channels
  • Then a ReLU, another convolution, and a Sigmoid follow in sequence, producing the attention weight map, shown as activation_19 in the figure below

    • this convolution has only 1 output channel, so the resulting map can later be multiplied with $x$; this is the attention
  • Finally, activation_19 is multiplied element-wise with $x$ (conv2d_126 at the top), which completes the whole process; a shape trace of these steps is sketched right after this list
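  To make the channel arithmetic above concrete, below is a minimal standalone shape trace of a single attention gate in the channels-first layout. The sizes (a 256-channel skip map at 20×20, a 512-channel gating signal, inter_channel = 512 // 4 = 128) are taken from the model summary further down; this is only a sketch of the mechanism, the full implementation follows.

from keras import Input
from keras.layers import Conv2D, Activation, add, multiply

x = Input((256, 20, 20))   # skip-connection feature map (conv2d_126 in the summary)
g = Input((512, 20, 20))   # upsampled decoder feature (up_sampling2d_11 in the summary)

theta_x = Conv2D(128, (1, 1), data_format='channels_first')(x)   # (None, 128, 20, 20)
phi_g = Conv2D(128, (1, 1), data_format='channels_first')(g)     # (None, 128, 20, 20)
f = Activation('relu')(add([theta_x, phi_g]))                    # (None, 128, 20, 20)
psi_f = Conv2D(1, (1, 1), data_format='channels_first')(f)       # (None, 1, 20, 20)
rate = Activation('sigmoid')(psi_f)                              # attention weights in [0, 1]
att_x = multiply([x, rate])                                      # (None, 256, 20, 20): the 1-channel
                                                                 # map broadcasts over all 256 channels of x
print(att_x.shape)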

  Implementation code:

from keras import Input 
from keras.layers import Conv2D, Activation, UpSampling2D, Lambda, Dropout, MaxPooling2D, multiply, add
from keras import backend as K 
from keras.models import Model 

IMG_CHANNEL = 3

def AttnBlock2D(x, g, inter_channel, data_format='channels_first'):
    # x: skip-connection feature map from the encoder
    # g: gating signal (the upsampled feature map from the deeper level)
    # inter_channel: number of channels of the intermediate 1x1 convolutions

    theta_x = Conv2D(inter_channel, [1, 1], strides=[1, 1], data_format=data_format)(x)

    phi_g = Conv2D(inter_channel, [1, 1], strides=[1, 1], data_format=data_format)(g)

    f = Activation('relu')(add([theta_x, phi_g]))

    psi_f = Conv2D(1, [1, 1], strides=[1, 1], data_format=data_format)(f)

    rate = Activation('sigmoid')(psi_f)

    att_x = multiply([x, rate])

    return att_x


def attention_up_and_concate(down_layer, layer, data_format='channels_first'):
    # down_layer: the decoder feature map from the deeper level
    # layer: the skip-connection feature map from the encoder
    
    if data_format == 'channels_first':
        in_channel = down_layer.get_shape().as_list()[1]
    else:
        in_channel = down_layer.get_shape().as_list()[3]
    
    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    layer = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4, data_format=data_format)

    if data_format == 'channels_first':
        my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=1))
    else:
        my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))  # the reference code had a bug here: x[1] was written as x[3]
    
    concate = my_concat([up, layer])
    return concate

# Attention U-Net 
def att_unet(img_w, img_h, n_label, data_format='channels_first'):
    # inputs = (3, 160, 160)
    inputs = Input((IMG_CHANNEL, img_w, img_h))
    x = inputs
    depth = 4
    features = 32
    skips = []
    # i = 0, 1, 2, 3
    for i in range(depth):
        # ENCODER
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        skips.append(x)
        x = MaxPooling2D((2, 2), data_format=data_format)(x)
        features = features * 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)

    # DECODER
    for i in reversed(range(depth)):
        features = features // 2
        x = attention_up_and_concate(x, skips[i], data_format=data_format)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    
    conv6 = Conv2D(n_label, (1, 1), padding='same', data_format=data_format)(x)
    conv7 = Activation('sigmoid')(conv6)
    
    model = Model(inputs=inputs, outputs=conv7)

    return model

IMG_WIDTH = 160
IMG_HEIGHT = 160

model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()

from keras.utils.vis_utils import plot_model 
plot_model(model, to_file='Att_U_Net.png', show_shapes=True)

   Output:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_11 (InputLayer)           [(None, 3, 160, 160) 0                                            
__________________________________________________________________________________________________
conv2d_119 (Conv2D)             (None, 32, 160, 160) 896         input_11[0][0]                   
__________________________________________________________________________________________________
dropout_45 (Dropout)            (None, 32, 160, 160) 0           conv2d_119[0][0]                 
__________________________________________________________________________________________________
conv2d_120 (Conv2D)             (None, 32, 160, 160) 9248        dropout_45[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_32 (MaxPooling2D) (None, 32, 80, 80)   0           conv2d_120[0][0]                 
__________________________________________________________________________________________________
conv2d_121 (Conv2D)             (None, 64, 80, 80)   18496       max_pooling2d_32[0][0]           
__________________________________________________________________________________________________
dropout_46 (Dropout)            (None, 64, 80, 80)   0           conv2d_121[0][0]                 
__________________________________________________________________________________________________
conv2d_122 (Conv2D)             (None, 64, 80, 80)   36928       dropout_46[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_33 (MaxPooling2D) (None, 64, 40, 40)   0           conv2d_122[0][0]                 
__________________________________________________________________________________________________
conv2d_123 (Conv2D)             (None, 128, 40, 40)  73856       max_pooling2d_33[0][0]           
__________________________________________________________________________________________________
dropout_47 (Dropout)            (None, 128, 40, 40)  0           conv2d_123[0][0]                 
__________________________________________________________________________________________________
conv2d_124 (Conv2D)             (None, 128, 40, 40)  147584      dropout_47[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_34 (MaxPooling2D) (None, 128, 20, 20)  0           conv2d_124[0][0]                 
__________________________________________________________________________________________________
conv2d_125 (Conv2D)             (None, 256, 20, 20)  295168      max_pooling2d_34[0][0]           
__________________________________________________________________________________________________
dropout_48 (Dropout)            (None, 256, 20, 20)  0           conv2d_125[0][0]                 
__________________________________________________________________________________________________
conv2d_126 (Conv2D)             (None, 256, 20, 20)  590080      dropout_48[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_35 (MaxPooling2D) (None, 256, 10, 10)  0           conv2d_126[0][0]                 
__________________________________________________________________________________________________
conv2d_127 (Conv2D)             (None, 512, 10, 10)  1180160     max_pooling2d_35[0][0]           
__________________________________________________________________________________________________
dropout_49 (Dropout)            (None, 512, 10, 10)  0           conv2d_127[0][0]                 
__________________________________________________________________________________________________
conv2d_128 (Conv2D)             (None, 512, 10, 10)  2359808     dropout_49[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_11 (UpSampling2D) (None, 512, 20, 20)  0           conv2d_128[0][0]                 
__________________________________________________________________________________________________
conv2d_129 (Conv2D)             (None, 128, 20, 20)  32896       conv2d_126[0][0]                 
__________________________________________________________________________________________________
conv2d_130 (Conv2D)             (None, 128, 20, 20)  65664       up_sampling2d_11[0][0]           
__________________________________________________________________________________________________
add_6 (Add)                     (None, 128, 20, 20)  0           conv2d_129[0][0]                 
                                                                 conv2d_130[0][0]                 
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 128, 20, 20)  0           add_6[0][0]                      
__________________________________________________________________________________________________
conv2d_131 (Conv2D)             (None, 1, 20, 20)    129         activation_18[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 1, 20, 20)    0           conv2d_131[0][0]                 
__________________________________________________________________________________________________
multiply_6 (Multiply)           (None, 256, 20, 20)  0           conv2d_126[0][0]                 
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 768, 20, 20)  0           up_sampling2d_11[0][0]           
                                                                 multiply_6[0][0]                 
__________________________________________________________________________________________________
conv2d_132 (Conv2D)             (None, 256, 20, 20)  1769728     lambda_5[0][0]                   
__________________________________________________________________________________________________
dropout_50 (Dropout)            (None, 256, 20, 20)  0           conv2d_132[0][0]                 
__________________________________________________________________________________________________
conv2d_133 (Conv2D)             (None, 256, 20, 20)  590080      dropout_50[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, 256, 40, 40)  0           conv2d_133[0][0]                 
__________________________________________________________________________________________________
conv2d_134 (Conv2D)             (None, 64, 40, 40)   8256        conv2d_124[0][0]                 
__________________________________________________________________________________________________
conv2d_135 (Conv2D)             (None, 64, 40, 40)   16448       up_sampling2d_12[0][0]           
__________________________________________________________________________________________________
add_7 (Add)                     (None, 64, 40, 40)   0           conv2d_134[0][0]                 
                                                                 conv2d_135[0][0]                 
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 64, 40, 40)   0           add_7[0][0]                      
__________________________________________________________________________________________________
conv2d_136 (Conv2D)             (None, 1, 40, 40)    65          activation_20[0][0]              
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 1, 40, 40)    0           conv2d_136[0][0]                 
__________________________________________________________________________________________________
multiply_7 (Multiply)           (None, 128, 40, 40)  0           conv2d_124[0][0]                 
                                                                 activation_21[0][0]              
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 384, 40, 40)  0           up_sampling2d_12[0][0]           
                                                                 multiply_7[0][0]                 
__________________________________________________________________________________________________
conv2d_137 (Conv2D)             (None, 128, 40, 40)  442496      lambda_6[0][0]                   
__________________________________________________________________________________________________
dropout_51 (Dropout)            (None, 128, 40, 40)  0           conv2d_137[0][0]                 
__________________________________________________________________________________________________
conv2d_138 (Conv2D)             (None, 128, 40, 40)  147584      dropout_51[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 128, 80, 80)  0           conv2d_138[0][0]                 
__________________________________________________________________________________________________
conv2d_139 (Conv2D)             (None, 32, 80, 80)   2080        conv2d_122[0][0]                 
__________________________________________________________________________________________________
conv2d_140 (Conv2D)             (None, 32, 80, 80)   4128        up_sampling2d_13[0][0]           
__________________________________________________________________________________________________
add_8 (Add)                     (None, 32, 80, 80)   0           conv2d_139[0][0]                 
                                                                 conv2d_140[0][0]                 
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 32, 80, 80)   0           add_8[0][0]                      
__________________________________________________________________________________________________
conv2d_141 (Conv2D)             (None, 1, 80, 80)    33          activation_22[0][0]              
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 1, 80, 80)    0           conv2d_141[0][0]                 
__________________________________________________________________________________________________
multiply_8 (Multiply)           (None, 64, 80, 80)   0           conv2d_122[0][0]                 
                                                                 activation_23[0][0]              
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 192, 80, 80)  0           up_sampling2d_13[0][0]           
                                                                 multiply_8[0][0]                 
__________________________________________________________________________________________________
conv2d_142 (Conv2D)             (None, 64, 80, 80)   110656      lambda_7[0][0]                   
__________________________________________________________________________________________________
dropout_52 (Dropout)            (None, 64, 80, 80)   0           conv2d_142[0][0]                 
__________________________________________________________________________________________________
conv2d_143 (Conv2D)             (None, 64, 80, 80)   36928       dropout_52[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 64, 160, 160) 0           conv2d_143[0][0]                 
__________________________________________________________________________________________________
conv2d_144 (Conv2D)             (None, 16, 160, 160) 528         conv2d_120[0][0]                 
__________________________________________________________________________________________________
conv2d_145 (Conv2D)             (None, 16, 160, 160) 1040        up_sampling2d_14[0][0]           
__________________________________________________________________________________________________
add_9 (Add)                     (None, 16, 160, 160) 0           conv2d_144[0][0]                 
                                                                 conv2d_145[0][0]                 
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 16, 160, 160) 0           add_9[0][0]                      
__________________________________________________________________________________________________
conv2d_146 (Conv2D)             (None, 1, 160, 160)  17          activation_24[0][0]              
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 1, 160, 160)  0           conv2d_146[0][0]                 
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 32, 160, 160) 0           conv2d_120[0][0]                 
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 96, 160, 160) 0           up_sampling2d_14[0][0]           
                                                                 multiply_9[0][0]                 
__________________________________________________________________________________________________
conv2d_147 (Conv2D)             (None, 32, 160, 160) 27680       lambda_8[0][0]                   
__________________________________________________________________________________________________
dropout_53 (Dropout)            (None, 32, 160, 160) 0           conv2d_147[0][0]                 
__________________________________________________________________________________________________
conv2d_148 (Conv2D)             (None, 32, 160, 160) 9248        dropout_53[0][0]                 
__________________________________________________________________________________________________
conv2d_149 (Conv2D)             (None, 1, 160, 160)  33          conv2d_148[0][0]                 
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 1, 160, 160)  0           conv2d_149[0][0]                 
==================================================================================================
Total params: 7,977,941
Trainable params: 7,977,941
Non-trainable params: 0
__________________________________________________________________________________________________

   The architecture diagram is as follows:

  Supplementary code for the case where the channels come last (channels_last):

from keras import Input 
from keras.layers import Conv2D, Activation, UpSampling2D, Lambda, Dropout, MaxPooling2D, multiply, add
from keras import backend as K 
from keras.models import Model 

IMG_CHANNEL = 3

def AttnBlock2D(x, g, inter_channel):
    # x: skip connection layer
    # g: the down layer after upsampling
    # inter_channel: number of channels of the down layer // 4
    
    theta_x = Conv2D(inter_channel, [1, 1], strides=[1, 1])(x)
    phi_g = Conv2D(inter_channel, [1, 1], strides=[1, 1])(g)
    f = Activation('relu')(add([theta_x, phi_g]))
    psi_f = Conv2D(1, [1, 1], strides=[1, 1])(f)
    rate = Activation('sigmoid')(psi_f)
    att_x = multiply([x, rate])

    return att_x

def attention_up_and_concate(down_layer, layer):
    # down_layer: the feature map carried up from the deeper level
    # layer: skip connection layer
    
    in_channel = down_layer.get_shape().as_list()[3]
    up = UpSampling2D(size=(2, 2))(down_layer)
    layer = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4)
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))
    concate = my_concat([up, layer])
    
    return concate

# Attention U-Net 
def att_unet(img_w, img_h, n_label):
    inputs = Input((img_w, img_h, IMG_CHANNEL))
    x = inputs
    depth = 4
    features = 32
    skips = []
    
    # i = 0, 1, 2, 3
    # ENCODER
    for i in range(depth):
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        features = features * 2

    # BOTTLENECK
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
    x = Dropout(0.2)(x)
    x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)

    # DECODER
    for i in reversed(range(depth)):
        features = features // 2
        x = attention_up_and_concate(x, skips[i])
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(features, (3, 3), activation='relu', padding='same')(x)
    
    conv6 = Conv2D(n_label, (1, 1), padding='same')(x)
    conv7 = Activation('sigmoid')(conv6)
    
    model = Model(inputs=inputs, outputs=conv7)

    return model

IMG_WIDTH = 160
IMG_HEIGHT = 160

model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()
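For completeness, here is a minimal sketch of how the channels-last model could be compiled and sanity-checked on random data; the optimizer, loss, and dummy arrays below are illustrative assumptions, not part of the referenced code:

import numpy as np

# Illustrative only: compile the model and run one step on random data to verify shapes.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

dummy_x = np.random.rand(2, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL).astype('float32')          # 2 fake RGB images
dummy_y = np.random.randint(0, 2, size=(2, IMG_HEIGHT, IMG_WIDTH, 1)).astype('float32')    # 2 fake binary masks

model.fit(dummy_x, dummy_y, batch_size=2, epochs=1)
pred = model.predict(dummy_x)
print(pred.shape)   # expected: (2, 160, 160, 1)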

 

