tensorflow: a Implementation of rotation ops (旋轉的函數實現方法)


tensorflow 旋轉矩陣的函數實現方法

關鍵字: rot90, tensorflow

1. 背景

在做數據增強的操作過程中, 很多情況需要對圖像旋轉和平移等操作, 針對一些特殊的卷積(garbo conv)操作,還需要對卷積核進行旋轉操作.
在tensorflow中似乎沒有實現對4D tensor的旋轉操作.
嚴格的說: tensorflow對tensor的翻轉操作並未實現, 僅有針對3D tensor的tf.image.rot()
而在大多數的情況下使用的是4D形式的tensor, [B,W,H,C] 或者是3D的圖像組成的batchs.

通過查看這篇文章的代碼可以知道[1] 可以使用numpy的rot90()函數旋轉, 但是rot90對象是ndarray, 針對tensorflow.tensor對象而言顯然是無法使用的, 會拋出類似: 無法找到m.dim屬性的異常.
也就是說無法使用numpy.rot90() 函數.

又知, tensorflow中提供有對矩陣的翻轉, 轉置,切片操作的函數,但是沒有提供旋轉90°, 180°,270°的操作.
因此可以參照numpy.rot90(m, k=1, axes=(0,1)) 的程序片段去自己動手實現.
rot90中的第一個參數m是操作對象, k是旋轉的次數,k=1 代表逆時針旋轉90度, k=2 代表逆時針旋轉180度,以此類推
axes是代表旋轉的操作在哪兩個維度構成的平面上.

rot90的源代碼如下:

def rot90(m, k=1, axes=(0,1)):
    '''
    ......
    '''
    # 省略檢測參數的操作
    k %= 4

    if k == 0:
        return m[:]
    if k == 2:
        return flip(flip(m, axes[0]), axes[1])

    axes_list = arange(0, m.ndim)
    (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],
                                                axes_list[axes[0]])

    if k == 1:
        return transpose(flip(m,axes[1]), axes_list)
    else:
        # k == 3
        return flip(transpose(m, axes_list), axes[1])

PS: 通過閱讀上述的代碼,也可以發現在tensorflow中直接使用rot90所拋出的異常是在這里出現的

if axes[0] == axes[1] or absolute(axes[0] - axes[1]) == m.ndim

原因是: 程序把tensor對象當成np.ndarray操作了, 而tensor對象沒有m.dim屬性

2. 實現rot90操作

2.1 梳理程序流程

通過查看源代碼可以梳理出程序流程圖:

程序流程圖

2.2 tensorflow 實現旋轉操作

根據上述的流程圖, 可以實現對tensorflow的rot90操作;

def rot90(tensor,k=1,axes=[1,2],name=None):
    '''
    autor:lizh
    tensor: a tensor 4 or more dimensions
    k: integer, Number of times the array is rotated by 90 degrees.
    axes: (2,) array_like
        The array is rotated in the plane defined by the axes.
        Axes must be different.
    
    -----
    Returns
    -------
    tensor : tf.tensor
             A rotated view of `tensor`.
    See Also: https://www.tensorflow.org/api_docs/python/tf/image/rot90 
    '''
    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")
        
    tenor_shape = (tensor.get_shape().as_list())
    dim = len(tenor_shape)
    
    if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == dim:
        raise ValueError("Axes must be different.")
        
    if (axes[0] >= dim or axes[0] < -dim 
        or axes[1] >= dim or axes[1] < -dim):
        
        raise ValueError("Axes={} out of range for tensor of ndim={}."
            .format(axes, dim))
    k%=4
    if k==0:
        return tensor
    if k==2:
        img180 = tf.reverse(tf.reverse(tensor, axis=[axes[0]]),axis=[axes[1]],name=name)
        return img180
    
    axes_list = np.arange(0, dim)
    (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],axes_list[axes[0]]) # 替換
    
    print(axes_list)
    if k==1:
        img90=tf.transpose(tf.reverse(tensor,axis=[axes[1]]), perm=axes_list, name=name)
        return img90
    if k==3:
        img270=tf.reverse( tf.transpose(tensor, perm=axes_list),axis=[axes[1]],name=name)
        return img270

2.3 代碼測試

# 加載庫
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# 手寫體數據集 加載
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/home/lizhen/data/MNIST/", one_hot=True)

sess=tf.Session()
#選取數據 4D
images = mnist.train.images
img_raw = images[0,:] # [0,784]
img=tf.reshape(img_raw,[-1,28,28,1]) # img 現在是tensor
# 繪圖
def fig_2D_tensor(tensor):# 繪圖
    #plt.matshow(tensor, cmap=plt.get_cmap('gray'))
    plt.matshow(tensor) # 彩色圖像
    # plt.colorbar() # 顏色條
    plt.show()
# 顯 顯示 待旋轉的圖片
fig_2D_tensor(sess.run(img)[0,:,:,0]) # 提取ndarray

待操作的圖片

簡單的測試一下代碼:

img11_rot=rot90(img,2) # 旋轉兩次90
fig_2D_tensor(sess.run(img11_rot)[0,:,:,0]) # 打印圖像

img12_rot=rot90(img,1,[1,1]) # 拋出異常,  測試 Axes must be different.
img13_rot=rot90(img,1,[0,6]) # 拋出異常,  測試 Axes must be different.

img14_rot=rot90(img,axes=[1,5])# 拋出異常,測試out of range.

img14_rot=rot90(img,axes=[-1,2]) # -1的下標是倒數第二個,測試out of range.

測試結果:

3總結

okey了,現在可以用了.
.....

額,,,,,最近才發現tensorflow的最新版本,大約就在前幾天發布的新版本(14天前, 1.10.1 )上已經添加了對2D,3D圖像的操作,支持[B,W,H,C]格式的tensor做出旋轉[2]

星期五, 07. 九月 2018 02:49下午

參考文獻


  1. Understanding 2D Dilated Convolution Operation with Examples in Numpy and Tensorflow with Interactive Code ↩︎

  2. tensorflow/python/ops/image_ops#rot90 ↩︎


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM