Tensorflow Dataset.from_generator使用示例


 

shapes = (tf.TensorShape([None, None]), tf.TensorShape([10, 10]))
# 傳入的是一個generator,即返回字段為yield的函數,不可傳入嵌套生成器
# dataSet output_types參數必選,output_shapes參數可選,不選會直接適配數據的shape
# 參數就是一個元組
data_set = tf.data.Dataset.from_generator(gen_epochs,
                                          output_types=(tf.int32, tf.int32),
                                          output_shapes=shapes,
                                          args=(n, batch_size, 10))

 

之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介紹了使用Tensorflow Dataset進行數據導入的方法及其優勢。最近在實際使用中越發感覺到這個方式非常好用,尤其是發現了.from_generator這個method。

 

關於Dataset.from_generator的簡單介紹,請參見如下兩個鏈接:

https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat

https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369

 

注意,Dataset.from_generator在舊版Tensorflow中沒有,起碼在1.3版本tf.contrib.data.Dataset中還沒有,后來用的1.7版本就有了。

 

我們知道,tensorflow的基本原理是先構造一個計算圖,最后再統一計算。為此,tf重寫了幾乎所有常見函數,用於構造計算圖,而且tensorflow不支持循環、選擇等普通編程語言的常見操作。這就給編程使用帶來比較大的麻煩。具體到data feeding上,也是如此。雖然設計了placeholder、train.slice_input_producer系列、Dataset等多種方式,但使用中仍有各種不便,尤其是在輸入形式復雜、需要多重變換的時候更是如此。而Dataset.from_generator可以在一定程度上解決這個問題。

 

簡單的說,Dataset.from_generator可以使用普通編程語言編寫的外部子函數生成Dataset,這樣幾乎不受tensorflow編程不便的影響。先舉一個最簡單的示例:

'''
import pickle
fr=open('/media/dell/D/qcc/RandLA-Net/data/semantic_kitti/dataset/sequences_0.06/00/KDTree/000001.pkl','rb')
inf = pickle.load(fr)
doc = open('1.txt', 'a')
print(inf, file=doc)
print(inf)
'''

# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834

"""
Expected outputs:
Batch No. 0:
[0 1 2 3]
Batch No. 1:
[4 0 1 2]
Batch No. 2:
[3 4 0 1]
Batch No. 3:
[2 3 4]
end!
"""

import numpy as np
import tensorflow as tf


def data_generator():
    dataset = np.array(range(5))
    for d in dataset:
        #print(d)
        yield d


dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
dataset = dataset.repeat(3) #3==epoch
dataset = dataset.batch(4) #4==batchsize

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    try:
        batch_num = 0
        while True:
            one_batch = sess.run(one_element)
            print('Batch No. %d:' % batch_num)
            print(one_batch)
            print('')
            batch_num += 1

    except tf.errors.OutOfRangeError:
        print('end!')
        
        

很顯然,這個的輸出如下:

  1. Batch No. 0:
  2. [ 0 1 2 3]
  3.  
  4. Batch No. 1:
  5. [ 4 0 1 2]
  6.  
  7. Batch No. 2:
  8. [ 3 4 0 1]
  9.  
  10. Batch No. 3:
  11. [ 2 3 4]
  12.  
  13. end!

 

下面給出一個復雜的問題。假設需要輸入如下序列:

A B

A C B

C

其中A/B/C分別代表一個文件,例如一張圖片或是一個文本文件。每一行是一條記錄,按行讀入,並聚集多行形成batch,譬如每4行形成一個batch。這里有兩個難點:1.每一行/每一條記錄的元素長度不一樣;2.讀入元素A/B/C之后還要以之作為文件名讀入文件內容。現有各種data feeding方式似乎很難同時解決這兩個難點,除了Dataset.from_generator。

 

針對這個問題,使用Dataset.from_generator的一個簡化版示例如下:

  1. # demo of Dataset.from_generator
  2. # blog.csdn.net/foreseerwang
  3. # QQ: 50834
  4.  
  5. """
  6. Expected outputs:
  7.  
  8. Batch No. 0:
  9. [[ 1 2 3]
  10. [ 2 3 -1]]
  11.  
  12. Batch No. 1:
  13. [[ 3 -1 -1]
  14. [ 4 5 -1]]
  15.  
  16. Batch No. 2:
  17. [[ 6 7 8]
  18. [ 9 -1 -1]]
  19.  
  20. Batch No. 3:
  21. [[10 11 12]
  22. [13 14 -1]]
  23.  
  24. Batch No. 4:
  25. [[15 -1 -1]]
  26.  
  27. end!
  28. """
  29.  
  30. import io
  31. import numpy as np
  32. import tensorflow as tf
  33.  
  34. class DataFeeder:
  35.  
  36. def __init__(self, filenames):
  37. self.filenames = filenames
  38.  
  39. def file_readline(self):
  40. for filename in self.filenames:
  41. fr = io.open(filename, 'r', encoding='utf-8')
  42.  
  43. while True:
  44. file_line = fr.readline()
  45. if not file_line:
  46. break
  47.  
  48. datalist = file_line.split()
  49. # if datalist is a list of filename, file contents can
  50. # be read and appendded here.
  51. yield np.asarray(datalist, dtype='int32')
  52.  
  53. fr.close()
  54.  
  55. def generate_batch(self, batch_size, num_epochs=None):
  56. dataset = tf.data.Dataset.from_generator(self.file_readline,
  57. tf.int32,
  58. tf.TensorShape([ None]))
  59.  
  60. dataset = dataset.repeat(num_epochs)
  61. dataset = dataset.padded_batch(
  62. batch_size,
  63. padded_shapes=tf.TensorShape([ 3]),
  64. padding_values= -1)
  65.  
  66. iterator = dataset.make_one_shot_iterator()
  67. out_batch = iterator.get_next()
  68.  
  69. return out_batch
  70.  
  71. filenames = [ 'a.txt', 'b.txt', 'c.txt']
  72. data_feeder = DataFeeder(filenames)
  73. one_batch = data_feeder.generate_batch(batch_size= 2, num_epochs=1)
  74.  
  75. with tf.Session() as sess:
  76. try:
  77. batch_num = 0
  78. while True:
  79. data_batch = sess.run(one_batch)
  80. print( 'Batch No. %d:' % batch_num)
  81. print(data_batch)
  82. print( '')
  83. batch_num+= 1
  84.  
  85. except tf.errors.OutOfRangeError:
  86. print( 'end!')

 

其中三個文本文件a.txt/b.txt/c.txt的內容分別如下:

a.txt:

1 2 3 2 3 3

b.txt:

4 5 6 7 8 9

c.txt:

10 11 12 13 14 15

 

運行以上代碼的輸出為:

  1. Batch No. 0:
  2. [[ 1 2 3]
  3. [ 2 3 -1]]
  4.  
  5. Batch No. 1:
  6. [[ 3 -1 -1]
  7. [ 4 5 -1]]
  8.  
  9. Batch No. 2:
  10. [[ 6 7 8]
  11. [ 9 -1 -1]]
  12.  
  13. Batch No. 3:
  14. [[ 10 11 12]
  15. [ 13 14 -1]]
  16.  
  17. Batch No. 4:
  18. [[ 15 -1 -1]]
  19.  
  20. end!

目前的輸出,每個batch是batch_size * 3的矩陣。實際上,1~15的數字可以是某個圖片的文件名,在file_readline()函數中讀出這些數字后,可以繼續讀出這些文件的內容,並形成更高維度的Dataset輸出,譬如:batch_size * img_size * img_size * img_channel的Dataset。

 

最后,說幾點注意事項(詳見代碼):

1. generator函數不能有輸入參數,但如果是class內的一個函數,可以使用self參數,這也是傳遞參數的一個手段;

2. 上述class中,建議傳遞文件名,在generator中打開處理再關閉,而不應該在外面打開(fr=open(filename, ‘r’)),然后把fr傳遞給generator讀取。實踐表明:后面這種方法形成的dataset不能repeat;

3. 因為序列不等長,在形成dataset batch時需要使用Dataset.padded_batch方法。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM