tf.split(dimension, num_split, input):dimension的意思就是輸入張量的哪一個維度,如果是0就表示對第0維度進行切割。num_split就是切割的數量,如果是2就表示輸入張量被切成2份,每一份是一個列表。
- import tensorflow as tf;
- import numpy as np;
- A = [[1,2,3],[4,5,6]]
- x = tf.split(1, 3, A)
- with tf.Session() as sess:
- c = sess.run(x)
- for ele in c:
- print ele
輸出:
[[1]
[4]]
[[2]
[5]]
[[3]
[6]]
TensorFlow函數:tf.split
由 Carrie 創建, 最后一次修改 2018-03-15
tf.split函數
split(
value,
num_or_size_splits,
axis=0, num=None, name='split' )
定義在:tensorflow/python/ops/array_ops.py。
參見指南:張量變換>切割和連接
將張量分割成子張量。
如果 num_or_size_splits 是整數類型,num_split,則 value 沿維度 axis 分割成為 num_split 更小的張量。要求 num_split 均勻分配 value.shape[axis]。
如果 num_or_size_splits 不是整數類型,則它被認為是一個張量 size_splits,然后將 value 分割成 len(size_splits) 塊。第 i 部分的形狀與 value 的大小相同,除了沿維度 axis 之外的大小 size_splits[i]。
例如:
# 'value' is a tensor with shape [5, 30] # Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1 split0, split1, split2 = tf.split(value, [4, 15, 11], 1) tf.shape(split0) # [5, 4] tf.shape(split1) # [5, 15] tf.shape(split2) # [5, 11] # Split 'value' into 3 tensors along dimension 1 split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1) tf.shape(split0) # [5, 10]
函數參數:
- value:要分割的 Tensor。
- num_or_size_splits:指示沿 split_dim 分割數量的 0-D 整數 Tensor 或包含沿 split_dim 每個輸出張量大小的 1-D 整數 Tensor ;如果為一個標量,那么它必須均勻分割 value.shape[axis];否則沿分割維度的大小總和必須與該 value 相匹配。
- axis:A 0-D int32 Tensor;表示分割的尺寸;必須在[-rank(value), rank(value))范圍內;默認為0。
- num:可選的,用於指定無法從 size_splits 的形狀推斷出的輸出數。
- name:操作的名稱(可選)。
函數返回值:
如果 num_or_size_splits 是標量,返回 num_or_size_splits Tensor對象;如果 num_or_size_splits 是一維張量,則返回由 value 分割產生的 num_or_size_splits.get_shape[0] Tensor對象。
函數可能引發的異常:
- ValueError:如果 num 沒有指定並且無法推斷。