tf.split


tf.split(dimension, num_split, input):dimension的意思就是輸入張量的哪一個維度,如果是0就表示對第0維度進行切割。num_split就是切割的數量,如果是2就表示輸入張量被切成2份,每一份是一個列表。

  1. import tensorflow as tf;  
  2. import numpy as np;  
  3.   
  4. A = [[1,2,3],[4,5,6]]  
  5. x = tf.split(1, 3, A)  
  6.   
  7. with tf.Session() as sess:  
  8.     c = sess.run(x)  
  9.     for ele in c:  
  10.         print ele  

 

輸出:

 

[[1]
 [4]]
[[2]
 [5]]
[[3]
 [6]]

 

 

 

TensorFlow函數:tf.split

由 Carrie 創建, 最后一次修改 2018-03-15

tf.split函數

split(
    value,
    num_or_size_splits,
    axis=0, num=None, name='split' )

定義在:tensorflow/python/ops/array_ops.py

參見指南:張量變換>切割和連接

將張量分割成子張量。 

如果 num_or_size_splits 是整數類型,num_split,則 value 沿維度 axis 分割成為 num_split 更小的張量。要求 num_split 均勻分配 value.shape[axis]。

如果 num_or_size_splits 不是整數類型,則它被認為是一個張量 size_splits,然后將 value 分割成 len(size_splits) 塊。第 i 部分的形狀與 value 的大小相同,除了沿維度 axis 之外的大小 size_splits[i]。

例如:

# 'value' is a tensor with shape [5, 30] # Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1 split0, split1, split2 = tf.split(value, [4, 15, 11], 1) tf.shape(split0) # [5, 4] tf.shape(split1) # [5, 15] tf.shape(split2) # [5, 11] # Split 'value' into 3 tensors along dimension 1 split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1) tf.shape(split0) # [5, 10]

函數參數:

  • value:要分割的 Tensor。
  • num_or_size_splits:指示沿 split_dim 分割數量的 0-D 整數 Tensor 或包含沿 split_dim 每個輸出張量大小的 1-D 整數 Tensor ;如果為一個標量,那么它必須均勻分割 value.shape[axis];否則沿分割維度的大小總和必須與該 value 相匹配。
  • axis:A 0-D int32 Tensor;表示分割的尺寸;必須在[-rank(value), rank(value))范圍內;默認為0。
  • num:可選的,用於指定無法從 size_splits 的形狀推斷出的輸出數。
  • name:操作的名稱(可選)。

函數返回值:

如果 num_or_size_splits 是標量,返回 num_or_size_splits Tensor對象;如果 num_or_size_splits 是一維張量,則返回由 value 分割產生的 num_or_size_splits.get_shape[0] Tensor對象。

函數可能引發的異常:

  • ValueError:如果 num 沒有指定並且無法推斷。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM