『TensorFlow』張量尺寸獲取


tf.shape(a)和a.get_shape()比較

相同點:都可以得到tensor a的尺寸

不同點:tf.shape()中a 數據的類型可以是tensor, list, array

    a.get_shape()中a的數據類型只能是tensor,且返回的是一個元組(tuple)

import tensorflow as tf  
import numpy as np  

x=tf.constant([[1,2,3],[4,5,6]]  
y=[[1,2,3],[4,5,6]]  
z=np.arange(24).reshape([2,3,4]))  

sess=tf.Session()  
# tf.shape()  
x_shape=tf.shape(x)                    #  x_shape 是一個tensor  
y_shape=tf.shape(y)                    #  <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32>  
z_shape=tf.shape(z)                    #  <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32>  
print sess.run(x_shape)              # 結果:[2 3]  
print sess.run(y_shape)              # 結果:[2 3]  
print sess.run(z_shape)              # 結果:[2 3 4]  


# a.get_shape()  
# 返回的是TensorShape([Dimension(2), Dimension(3)]),
# 不能使用 sess.run() 因為返回的不是tensor 或string,而是元組  
x_shape=x.get_shape()  
x_shape=x.get_shape().as_list()  # 可以使用 as_list()得到具體的尺寸,x_shape=[2 3]  
y_shape=y.get_shape()  # AttributeError: 'list' object has no attribute 'get_shape'  
z_shape=z.get_shape()  # AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'  
或者a.shape.as_list()

 

tf.shape(x)

  tf.shape()中x數據類型可以是tensor,list,array,返回是一個tensor.

shape=tf.placeholder(tf.float32, shape=[None, 227,227,3] )

  我們經常會這樣來feed數據,如果在運行的時候想知道None到底是多少,這時候,只能通過tf.shape(x)[0]這種方式來獲得.

  由於返回的時tensor,所以我們可以使用其他tensorflow節點操作進行處理,如下面的轉置卷積中,使用stack來合並各個shape的分量,

def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
    with tf.variable_scope('conv_transpose'):

        shape = [kernel, kernel, output_filters, input_filters]
        weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')

        batch_size = tf.shape(x)[0]
        height = tf.shape(x)[1] * strides
        width = tf.shape(x)[2] * strides
        output_shape = tf.stack([batch_size, height, width, output_filters])
return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')

tensor.get_shape()

  只有tensor有這個方法, 返回是一個tuple。也正是由於返回的是TensorShape([Dimension(2), Dimension(3)])這樣的元組,所以可以調用as_list化為[2, 3]樣list,或者get_shape()[i].value得到具體值.

tensor.set_shape()

  設置tensor的shape,一般不會用到,在tfrecode中,由於解析出來的tensor不會被設置shape,后續的函數是需要shape的維度等相關屬性的,所以這里會使用.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM