tensorflow-- Dataset創建數據集對象


tf.data模塊包含:

  •  experimental 模塊
  •  Dataset 類
  •  FixedLengthRecordDataset 類
  • TFRecordDataset 類
  • TextLineDataset 類
 
        
 1 #  author by FH.
 2 #  OverView:
 3 #  tf.data
 4 #           experimental  ---Modules
 5 #           Dataset      ---class
 6 #           FixedLengthRecordDataset  ---class
 7 #           TFRecordDataset           ---class
 8 #           TextLineDataset           ---class
 9 import tensorflow as tf
10 import numpy as np
11 
12 
13 # 1. 使用靜態方法 tf.data.Dataset.from_tensor_slices
14 #       將輸入的第一個維度切割,形成dataset
15 # 2. 使用 Dataset的 make_one_shot_iterator() 實例化一個 iterator
16 #       這個iterator 只能從頭到尾讀取一次。“one shot iterator”
17 def test1():
18     sess = tf.Session()
19     dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
20     dataset2 = tf.data.Dataset.from_tensor_slices(np.array([[1,2],[3,4],[0,9]]))
21     dataset3 = tf.data.Dataset.from_tensor_slices(
22         {
23             "a":np.array([1.0,2,3,4,5.0]),
24             "b":np.random.uniform(size=(5,2))
25         }
26     )
27     # 使用 Dataset的 make_one_shot_iterator() 實例化一個 iterator
28     #     這個iterator 只能從頭到尾讀取一次。“one shot iterator”
29     oneShotIterator1 = dataset1.make_one_shot_iterator()
30     oneShotIterator2 = dataset2.make_one_shot_iterator()
31     oneShotIterator3 = dataset3.make_one_shot_iterator()
32     element1 = oneShotIterator1.get_next()
33     element2 = oneShotIterator2.get_next()
34     element3 = oneShotIterator3.get_next()
35     for i in range(5):
36         print(sess.run(element1))
37     for i in range(3):
38         print(sess.run(element2))
39     for i in range(5):
40         print(sess.run(element3))
41     sess.close()
42 
43 # 1.Dataset 中的數據元素轉換。
44 #           map() :參數為一個函數,將dataset中的每個元素帶入獲取新的值
45 #           batch(): 參數為一個整數,將多個元素組合成一個batch
46 def test2():
47     sess = tf.Session()
48     dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0,6]))
49     # map() 重新映射新的元素值
50     dataset1 = dataset.map(lambda x: x * 3)
51     # batch()  2個組成一個batch, 組成batch 之后size 為3
52     dataset2 = dataset.batch(2)
53     # shuffle() 打亂dataset
54     dataset3 = dataset.shuffle(buffer_size=3)
55     # repeat()  將整個序列重復多次,重復4次 size 為24
56     dataset4 = dataset.repeat(4)
57 
58     oneShotIterator1 = dataset1.make_one_shot_iterator()
59     oneShotIterator2 = dataset2.make_one_shot_iterator()
60     oneShotIterator3 = dataset3.make_one_shot_iterator()
61     oneShotIterator4 = dataset4.make_one_shot_iterator()
62     element1 = oneShotIterator1.get_next()
63     element2 = oneShotIterator2.get_next()
64     element3 = oneShotIterator3.get_next()
65     element4 = oneShotIterator4.get_next()
66     for i in range(6):  # map()
67         print(sess.run(element1))
68     for i in range(3):  # batch()
69         print(sess.run(element2))
70     for i in range(6):  # shuffle()
71         print(sess.run(element3))
72     for i in range(24): # repeat()
73         print(sess.run(element4))
74     sess.close()
75 
76 # example1: 讀取圖片和相應的標簽並打亂,組成
77 #          batch_size=2 的數據集,重復10 epoch
78 def _parse_function(imgfilename,label):
79     image_value = tf.read_file(imgfilename)
80     img = tf.image.decode_image(image_value)
81     img = tf.image.resize_images(img,[256,256])
82     return img,label
83 def example1():
84     # 圖片列表
85     filesnames = tf.constant(['name1.jpg','name3.jpg','name5.jpg','name6.jpg','name7.jpg','name8.jpg'])
86     # 對應標簽
87     labels = tf.constant([0,1,0,1,1,0])
88     # dataset  (名稱,標簽)
89     dataset = tf.data.Dataset.from_tensor_slices((filesnames,labels))
90     # map 映射成圖片和標簽
91     dataset = dataset.map(_parse_function)
92     # shuffle ,batch , repeat
93     dataset = dataset.shuffle(buffersize=3).batch(2).repeat(10)
94     return dataset
95 
96 if __name__ == '__main__':
97     test2()
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM