簡介
這里的生成式網絡是廣義的生成式,不僅僅指gan網絡,還有風格遷移中的類自編碼器網絡,以及語義分割中的類自編碼器網絡,因為遇到次數比較多,所以簡單的記錄一下。
背景
1、像素和數字
圖像處理目標一般就是RGB三色通道,原始圖像解碼后是0~255,這個矩陣傳給matplotlib就可以直接繪圖了,與此同0~1的圖像matplotlib也是可以接受的,關於這點,我們來看看文檔是怎么說的,
Elements of RGB and RGBA arrays represent pixels of an MxN image.
All values should be in the range [0 .. 1] for floats or
[0 .. 255] for integers. Out-of-range values will be clipped to
these bounds.
即使0~1也能夠使用,我們常用的還是0~255的數據。
2、生成式網絡輸出的限制
生成式網絡不同於分類網絡,其輸出的目標是圖像,對照上面也就是0~255范圍(這個更常用)的矩陣,這就意味着網絡的輸出有所限制的,且是不同於分類網絡全部限制於0~1或者-1~1的,正如分類網絡的sigmoid或者softmax一樣,我們會在最后一個卷積/轉置卷積層后采取一些操作保證輸出滿足圖像的要求。
實際思路
輸入圖像為了保證可以被用於loss,需要和輸出圖像的值域相同,所以有兩個思路:
- 輸入圖像值壓縮到-1~1附近
- 輸出圖像值放大到0~255
gan網絡中
我們采用方式為:原像素數據除以127.5減去1的操作,使得輸出值保持在-1~1之間,可以配合sigmoid激活函數進行學習
實際測試一下,我們將這里的預處理(TFR_process.py)做一下調整,使得值不再被壓縮,
'''圖像預處理''' # image_decode = tf.cast(image_decode, tf.float32)/127.5-1 image_decode = tf.cast(image_decode, tf.float32)
相應的將生成網絡(DCGAN_function.py)作出調整,
h4 = deconv2d(h3, [batch_size, s_h, s_w, c_dim], scope='g_h4') return h4 # tf.nn.tanh(h4)
可以看到結果依舊可以訓練出來,效果如下。
快速風格遷移中
我們采用0~255作為輸入,生成數據仍為0~255(主要分布),然后將輸出數據進一步操作,送入vgg進行loss計算。
此時的生成式網絡最后一層可以不加激活,輸出會自行收斂在目標附近,也可以tanh激活(-1~1)后加1再乘127.5。