shapes = (tf.TensorShape([None, None]), tf.TensorShape([10, 10])) # 傳入的是一個generator,即返回字段為yield的函數,不可傳入嵌套生成器 # dataSet output_types參數必選,output_shapes參數可選,不選會直接適配數據的shape # 參數就是一個元組 data_set = tf.data.Dataset.from_generator(gen_epochs, output_types=(tf.int32, tf.int32), output_shapes=shapes, args=(n, batch_size, 10))
之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介紹了使用Tensorflow Dataset進行數據導入的方法及其優勢。最近在實際使用中越發感覺到這個方式非常好用,尤其是發現了.from_generator這個method。
關於Dataset.from_generator的簡單介紹,請參見如下兩個鏈接:
https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat
https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369
注意,Dataset.from_generator在舊版Tensorflow中沒有,起碼在1.3版本tf.contrib.data.Dataset中還沒有,后來用的1.7版本就有了。
我們知道,tensorflow的基本原理是先構造一個計算圖,最后再統一計算。為此,tf重寫了幾乎所有常見函數,用於構造計算圖,而且tensorflow不支持循環、選擇等普通編程語言的常見操作。這就給編程使用帶來比較大的麻煩。具體到data feeding上,也是如此。雖然設計了placeholder、train.slice_input_producer系列、Dataset等多種方式,但使用中仍有各種不便,尤其是在輸入形式復雜、需要多重變換的時候更是如此。而Dataset.from_generator可以在一定程度上解決這個問題。
簡單的說,Dataset.from_generator可以使用普通編程語言編寫的外部子函數生成Dataset,這樣幾乎不受tensorflow編程不便的影響。先舉一個最簡單的示例:
''' import pickle fr=open('/media/dell/D/qcc/RandLA-Net/data/semantic_kitti/dataset/sequences_0.06/00/KDTree/000001.pkl','rb') inf = pickle.load(fr) doc = open('1.txt', 'a') print(inf, file=doc) print(inf) ''' # demo of Dataset.from_generator # blog.csdn.net/foreseerwang # QQ: 50834 """ Expected outputs: Batch No. 0: [0 1 2 3] Batch No. 1: [4 0 1 2] Batch No. 2: [3 4 0 1] Batch No. 3: [2 3 4] end! """ import numpy as np import tensorflow as tf def data_generator(): dataset = np.array(range(5)) for d in dataset: #print(d) yield d dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([]))) dataset = dataset.repeat(3) #3==epoch dataset = dataset.batch(4) #4==batchsize iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: batch_num = 0 while True: one_batch = sess.run(one_element) print('Batch No. %d:' % batch_num) print(one_batch) print('') batch_num += 1 except tf.errors.OutOfRangeError: print('end!')
很顯然,這個的輸出如下:
-
Batch No. 0:
-
[ 0 1 2 3]
-
-
Batch No. 1:
-
[ 4 0 1 2]
-
-
Batch No. 2:
-
[ 3 4 0 1]
-
-
Batch No. 3:
-
[ 2 3 4]
-
-
end!
下面給出一個復雜的問題。假設需要輸入如下序列:
A B
A C B
C
…
其中A/B/C分別代表一個文件,例如一張圖片或是一個文本文件。每一行是一條記錄,按行讀入,並聚集多行形成batch,譬如每4行形成一個batch。這里有兩個難點:1.每一行/每一條記錄的元素長度不一樣;2.讀入元素A/B/C之后還要以之作為文件名讀入文件內容。現有各種data feeding方式似乎很難同時解決這兩個難點,除了Dataset.from_generator。
針對這個問題,使用Dataset.from_generator的一個簡化版示例如下:
-
# demo of Dataset.from_generator
-
# blog.csdn.net/foreseerwang
-
# QQ: 50834
-
-
"""
-
Expected outputs:
-
-
Batch No. 0:
-
[[ 1 2 3]
-
[ 2 3 -1]]
-
-
Batch No. 1:
-
[[ 3 -1 -1]
-
[ 4 5 -1]]
-
-
Batch No. 2:
-
[[ 6 7 8]
-
[ 9 -1 -1]]
-
-
Batch No. 3:
-
[[10 11 12]
-
[13 14 -1]]
-
-
Batch No. 4:
-
[[15 -1 -1]]
-
-
end!
-
"""
-
-
import io
-
import numpy as np
-
import tensorflow as tf
-
-
class DataFeeder:
-
-
def __init__(self, filenames):
-
self.filenames = filenames
-
-
def file_readline(self):
-
for filename in self.filenames:
-
fr = io.open(filename, 'r', encoding='utf-8')
-
-
while True:
-
file_line = fr.readline()
-
if not file_line:
-
break
-
-
datalist = file_line.split()
-
# if datalist is a list of filename, file contents can
-
# be read and appendded here.
-
yield np.asarray(datalist, dtype='int32')
-
-
fr.close()
-
-
def generate_batch(self, batch_size, num_epochs=None):
-
dataset = tf.data.Dataset.from_generator(self.file_readline,
-
tf.int32,
-
tf.TensorShape([ None]))
-
-
dataset = dataset.repeat(num_epochs)
-
dataset = dataset.padded_batch(
-
batch_size,
-
padded_shapes=tf.TensorShape([ 3]),
-
padding_values= -1)
-
-
iterator = dataset.make_one_shot_iterator()
-
out_batch = iterator.get_next()
-
-
return out_batch
-
-
filenames = [ 'a.txt', 'b.txt', 'c.txt']
-
data_feeder = DataFeeder(filenames)
-
one_batch = data_feeder.generate_batch(batch_size= 2, num_epochs=1)
-
-
with tf.Session() as sess:
-
try:
-
batch_num = 0
-
while True:
-
data_batch = sess.run(one_batch)
-
print( 'Batch No. %d:' % batch_num)
-
print(data_batch)
-
print( '')
-
batch_num+= 1
-
-
except tf.errors.OutOfRangeError:
-
print( 'end!')
其中三個文本文件a.txt/b.txt/c.txt的內容分別如下:
a.txt:
1 2 3 2 3 3
b.txt:
4 5 6 7 8 9
c.txt:
10 11 12 13 14 15
運行以上代碼的輸出為:
-
Batch No. 0:
-
[[ 1 2 3]
-
[ 2 3 -1]]
-
-
Batch No. 1:
-
[[ 3 -1 -1]
-
[ 4 5 -1]]
-
-
Batch No. 2:
-
[[ 6 7 8]
-
[ 9 -1 -1]]
-
-
Batch No. 3:
-
[[ 10 11 12]
-
[ 13 14 -1]]
-
-
Batch No. 4:
-
[[ 15 -1 -1]]
-
-
end!
目前的輸出,每個batch是batch_size * 3的矩陣。實際上,1~15的數字可以是某個圖片的文件名,在file_readline()函數中讀出這些數字后,可以繼續讀出這些文件的內容,並形成更高維度的Dataset輸出,譬如:batch_size * img_size * img_size * img_channel的Dataset。
最后,說幾點注意事項(詳見代碼):
1. generator函數不能有輸入參數,但如果是class內的一個函數,可以使用self參數,這也是傳遞參數的一個手段;
2. 上述class中,建議傳遞文件名,在generator中打開處理再關閉,而不應該在外面打開(fr=open(filename, ‘r’)),然后把fr傳遞給generator讀取。實踐表明:后面這種方法形成的dataset不能repeat;
3. 因為序列不等長,在形成dataset batch時需要使用Dataset.padded_batch方法。