TensorFlow中的卷積函數


前言

最近嘗試看TensorFlow中Slim模塊的代碼,看的比較郁悶,所以試着寫點小的代碼,動手驗證相關的操作,以增加直觀性。

卷積函數

slim模塊的conv2d函數,是二維卷積接口,順着源代碼可以看到最終調的TensorFlow接口是convolution,這個地方就進入C++層面了,暫時不涉及。先來看看這個convolution函數,官方定義是這樣的:

tf.nn.convolution(
    input,
    filter,
    padding,
    strides=None,
    dilation_rate=None,
    name=None,
    data_format=None
)

其中在默認情況下,也就是data_format=None的時候,input的要求格式是[batch_size] + input_spatial_shape + [in_channels],  也就是要求第一維是batch,最后一維是channel,中間是真正的卷積維度。所以這個接口不僅只支持2維卷積,猜測2維卷積tf.nn.conv2d是對此接口的封裝。[batch, height, weight, channel]就是conv2d的input參數格式,batch就是樣本數,或者更狹隘一點,圖片數量,height是圖片高,weight是圖片的寬,Slim的分類網絡都是height=weight的,以實現方陣運算,所有slim模塊中的原始圖片都需要經過預處理過程,這里不展開。

filter參數是卷積核的定義,spatial_filter_shape + [in_channels, out_channels],對於2維卷積同樣是4維參數[weight, height, channel, out_channel]。

明明是2維卷積,輸入都是4維,已經有點抽象了,所以進入下一個階段,寫段代碼,驗證一下吧。

實踐一下

這個例子先定義一個3X3的圖片,再定義一個2X2的卷積核,代碼如下:

import tensorflow as tf

input = tf.constant(
[
        [
                [
                        [100., 100., 100.],
                        [100., 100., 100.],
                        [100., 100., 100.]
                ],
                [
                        [100., 100., 100.],
                        [100., 100., 100.],
                        [100., 100., 100.]
                ],
                [
                        [100., 100., 100.],
                        [100., 100., 100.],
                        [100., 100., 100.],
                ]
        ]
]
);


filter = tf.constant(
[
        [
                [
                        [0.5],
                        [0.5],
                        [0.5]
                ],
                [
                        [0.5],
                        [0.5],
                        [0.5]
                ]
        ],
        [
                [
                        [0.5],
                        [0.5],
                        [0.5]
                ],
                [
                        [0.5],
                        [0.5],
                        [0.5]
                ]
        ],
]
);

result = tf.nn.convolution(input, filter, padding='VALID');

with tf.Session() as sess:
        print sess.run(result)

 從上述代碼可以看到,input的shape是[1, 3, 3, 3],filter的shape是[2, 2, 3, 1 ],卷積的過程在方陣[3, 3] 和 核[2, 2]上展開,並且由於有三個通道,每個通道分別卷積后求和。

代碼的執行結果:

[

  [

    [

      [600.]
      [600.]

    ]

    [

      [600.]

      [600.]

    ]

  ]

]

由於我們填的padding參數是VALID,所以最后的結果矩陣面積會縮小,滿足(3-2)+1,即 (iw - kw) + 1。

以上例子,我們可以將它稱為單張圖片二維3通道卷積,所以計算過程應該是每個通道進行卷積后最后三個通道的數值累加。

如果是從單個通道看,input就是:

[

  [100., 100., 100,]

  [100., 100., 100,]

  [100., 100., 100,]

]

卷積核:

[

  [0.5, 0.5]

  [0.5, 0.5]

]

那么單層卷積結果:

[

  [200., 200.]

  [200., 200.]

]

將三層結果疊加就是程序輸出結果。

增加輸出通道

slim.conv2d函數的第二參數就是輸出通道的數量,就是對應convolution接口filter的第4維,我們把程序改一下,增加一個輸出通道:

filter = tf.constant(
[
        [
                [
                        [0.5, 0.1],
                        [0.5, 0.1],
                        [0.5, 0.1]
                ],
                [
                        [0.5, 0.1],
                        [0.5, 0.1],
                        [0.5, 0.1]
                ]
        ],
        [
                [
                        [0.5, 0.1],
                        [0.5, 0.1],
                        [0.5, 0.1]
                ],
                [
                        [0.5, 0.1],
                        [0.5, 0.1],
                        [0.5, 0.1]
                ]
        ],
]
);

最后的輸出結果:

[

  [

    [

      [600. 120.]
      [600. 120.]

    ]
    [

      [600. 120.]
      [600. 120.]

    ]

  ]

]

其中 120 = 3 * (100 * 0.1 + 100 * 0.1 + 100 * 0.1 + 100 * 0.1)

從結果可以看到,輸出結果滿足 [batch_size] + output_spatial_shape + [out_channels]的格式。

padding=SAME更常用

上面的例子中使用了padding=VALID,是指不填充的情況下進行的有效卷積結果矩陣面積會收縮。而我們在閱卷幾個經典網絡時,都是使用padding=SAME的方式,這種方式下,結果輸出矩陣形狀不變,這樣就便於對不同分支結果進行連接等操作。

將第一個例子中的padding改為SAME,輸出結果為:

[

  [

    [

      [600.]
      [600.]
      [300.]

    ]
    [

      [600.]
      [600.]
      [300.]

    ]

    [

      [300.]
      [300.]
      [150.]

    ]

  ]

]

在SAME模式下,為了保證輸出結果輸入輸入形狀一致,實時上在原矩陣的的右側和底部擴展了行、列 0

暫時性結束

作為新手,一旦碰到多維就蒙了,所有以上的實踐,都是只是為了增加理解。

 




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM