tensorflow獲取shape


https://blog.csdn.net/TeFuirnever/article/details/88880350
https://blog.csdn.net/shuzfan/article/details/79051042

獲取tensor shape共有三中方式:x.shape、x.get_shape()、tf.shape(x)

x.shape:返回TensorShape類型,可以直接看到shape,打印出來的形式如:(3, 5)。注意x必須是tensor類型才能調用;

x.get_shape():同x.shape,是獲取x的shape屬性的函數;

tf.shape(x):返回的是Tensor類型,不能直接看到shape,只能看到shape的維數,如2。x可以不是tensor類型;另外對應維度為None的變量,想要獲得運行時實際的值,必須在運行時使用tf.shape(x)[0]方式獲得;

可以用 x.shape.as_list() 很方便獲取tensor的維度list,比如維度變換時:

# shape=[3, 2, 3]
array = [ [[1, 1, 1], [2, 2, 2]],
          [[3, 3, 3], [4, 4, 4]],
          [[5, 5, 5], [6, 6, 6]]
        ]

input = tf.constant(array)


shape_list = input.shape.as_list()
print(shape_list) # [3, 2, 3]

with tf.Session() as sess:
    output = tf.reshape(input, [-1,  shape_list[1]*shape_list[2]])
    print(sess.run(output)) # [[1 1 1 2 2 2]
                                 # [3 3 3 4 4 4]

input = [[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 10],
         [11, 12, 13, 14, 15]
        ]

# input = tf.random_normal([32, 10, 8])

# 轉換為tensor
input2 = tf.constant(input)

print(input2.shape) # (3, 5)
print(tf.shape(input2)) # Tensor("Shape_1:0", shape=(3,), dtype=int32) 

# 維度為None情況
tensor_x = tf.placeholder(tf.int64, [None, 42], name='tensor_x')
print(tensor_x.shape) # (?, 42)
print(tf.shape(tensor_x)) # Tensor("Shape:0", shape=(2,), dtype=int32)

with tf.Session() as sess:
  print(tf.shape(tensor_x)) # tensor_x未賦值,  維度存在None,會報錯
  print(tf.shape(input2)) # tensor_x維度不存在None,不會報錯


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM