TensorFlow2.0——划分數據集


將數據划分成若干批次的數據,可以使用tf.train或者tf.data.Dataset中的方法。

1. tf.data.Dataset

(1)划分方法

# 下面是,數據批次划分

    batch_size = 10
    # 將訓練數據的特征和標簽組合,使用from_tensor_slices將數據放入隊列
    dataset = tfdata.Dataset.from_tensor_slices((features, labels))
    # 使用shuffle(),隨機打亂數據集順序,不用shuffle就是按順序划分,buffer_size 參數應大於等於樣本數
    # dataset = dataset.shuffle(buffer_size=num_examples)
    # batch把dataset按照batch_size分批次,得到一個list集合。默認drop_remainder=False時,保留不足批次的部分,如果是True,就是舍去。
    dataset = dataset.batch(batch_size)
    # dataset = dataset.batch(batch_size).repeat()  # repeat表示重復次數,默認是None,表示數據序列無限延續

 

# 輸出

    # 輸出所有batch的list集合。
    # print(list(dataset.as_numpy_iterator()))

    # 輸出其中一個batch,兩種方法,官方推薦way2!
    print("way1")
    data_iter = iter(dataset)
    for X, y in data_iter:
        print(X, y)
        break
    print("way2")
    for (batch_num, (X, y)) in enumerate(dataset):
        print((X, y))  # batch_num是批次號,標識符,也可以起其他名字
        break

 

(2)dataset.batch()方法說明

batch把dataset按照batch_size分批次,得到一個list集合。默認drop_remainder=False時,保留不足批次的部分,如果是True,就是舍去。
list(dataset.as_numpy_iterator())方法可以輸出所有batch的list集合。
  def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator()) #這個方法可以輸出所有batch的list
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]

(3)dataset.repeat()方法說明

  def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior (if `count` is `None` or `-1`) is for the dataset be repeated indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """

 

2.tf.train

參考:https://www.cnblogs.com/jfl-xx/p/9945967.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM