tensorflow - 關於形狀、維度的操作匯總


size

Tensor 的 大小,長 * 寬;

tf.size 返回 Tensor,需要 session;

d1 = tf.random_uniform((3, 2))
# print(d1.size)    # AttributeError: 'Tensor' object has no attribute 'size'
size = tf.size(d1)
sess = tf.Session()
print(sess.run(size))       # 6

 

shape 和 tf.shape 和 get_shape  set_shape

 

先說結論再看例子

1. 運行環境不同

  • shape 和 get_shape 返回元組,故無需 Session,可直接獲取;
  • 而 tf.shape 返回 Tensor,需要 Session            【只有返回 Tensor 才需要 Session】

2. 適用對象不同

  • tf.shape 適用於 Tensor,還有 ndarray,list;
  • shape 適用於 Tensor,還有 ndarray;
  • get_shape 只適用於 Tensor;

 

代碼如下

########## tf.shape ##########
### 用函數獲取,返回 Tensor
# 針對所有 Tensor,包括 Variable,array、list 也可以
d5 = tf.shape(tf.random_normal((2, 3)))     ### Tensor
print(d5)       # Tensor("Shape:0", shape=(2,), dtype=int32)
d6 = tf.shape(tf.Variable([1. ,2.]))        ### Variable
n3 = tf.shape(np.array([[1, 2], [3, 4]]))   ### ndarray
n4 = tf.shape([1, 2])                       ### list
with tf.Session() as sess1:
    print(sess1.run(d5))        # [2 3]
    print(sess1.run(d6))        # [2]
    print(sess1.run(n3))        # [2 2]
    print(sess1.run(n4))        # [2]
    

########## shape ##########
### 直接獲取,返回元組
# 針對所有 Tensor,包括 Variable,array 也可以
d1 = tf.random_uniform((3, 2)).shape        ### Tensor
print(d1)       # (3, 2)
d2 = tf.Variable([1. ,2.]).shape            ### Variable
print(d2)       # (2,)
n1 = np.array([[1, 2], [3, 4]]).shape       ### ndarray
print(n1)       # (2, 2)


########## get_shape ##########
### 直接獲取,返回元組
# 針對所有 Tensor,包括 Variable,不包括 array
d3 = tf.random_uniform((3, 2)).get_shape()  ### Tensor
print(d3)       # (3, 2)
d4 = tf.Variable([1. ,2.]).get_shape()      ### Variable
print(d4)       # (2,)
# n2 = np.array([[1, 2], [3, 4]]).get_shape()     ### 報錯 AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

 

set_shape 是為一個 Tensor reset shape

一般只用於設置 placeholder 的尺寸;

x1 = tf.placeholder(tf.int32)
x1.set_shape([2, 2])
print(tf.shape(x1))

sess = tf.Session()
# print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1,2,3]]}))       ### ValueError: Cannot feed value of shape (1, 4) for Tensor 'Placeholder:0', which has shape '(2, 2)'
print(sess.run(tf.shape(x1), feed_dict={x1: [[0, 1], [2, 3]]}))

限制 x1 只能是 (2,2) 的 shape;

 

tf.squeeze 和 tf.expand_dims

def squeeze(input, axis=None, name=None, squeeze_dims=None)

壓縮維度,如果被壓縮的維度為 1 維,就去掉該維度,如果該維度不是 1 維,報錯

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t))  # [2, 3]

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]

axis 和 squeeze_dims 是一個意思, squeeze_dims 已被廢棄;

axis 可取 list 指定多個維度;

c1 = tf.constant([[1, 3]])
print(c1.shape)         # (1, 2)

### 第 0 維 的維度為 1
c2 = tf.squeeze(c1, squeeze_dims=0)
print(c2.shape)         # (2,)

c3 = tf.squeeze(c1, axis=0)
print(c3.shape)         # (2,)

 

def expand_dims(input, axis=None, name=None, dim=None)

在指定維度上增加 1 個維度

axis 和 dim 是一個意思,dim 已被廢棄

data = tf.constant([[1, 2],
                    [3, 4]])
print(data.shape)   # (2, 2)            ### 兩個維度

data2 = tf.expand_dims(data, dim=1)     ### 在第一個維度上添加一個維度
print(data2.shape)  # (2, 1, 2)

data3 = tf.expand_dims(data, dim=0)     ### 在第0個維度上添加一個維度
print(data3.shape)  # (1, 2, 2)

data4 = tf.expand_dims(data, dim=-1)    ### 在最后一個維度上添加一個維度
print(data4.shape)  # (2, 2, 1)

 

tf.concat

按指定維度進行拼接 

def concat(values, axis, name="concat") 

axis 0 表示按列拼接,1 表示按行拼接

d1 = tf.zeros((2, 3))
d2 = tf.ones((2, 4))

d3 = tf.concat([d1, d2], axis=1)        # 第 1 個維度
d4 = tf.concat([d1, d2], axis=-1)       # -1 代表最后一個維度

sess = tf.Session()
sess.run(d3)
sess.run(d4)
print(d3.shape)     # (2, 7)
print(d4.shape)     # (2, 7)

 

 

 

參考資料:

https://blog.csdn.net/m0_37744293/article/details/78254691  tf.shape()與tf.get_shape()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM