一、幾個函數
- RandomShuffleQueue類
__init__(self, capacity, min_after_dequeue,dtypes, shapes=None,names=None, seed=None, shared_name=None, name="random_shuffle_queue")
queue = tf.RandomShuffleQueue(...):創建一個queue,按隨機順序進行dequeue
RandomShuffleQueue有一定的容量限制capacity,支持多個生產者和消費者
RandomShuffleQueue中的每個元素是固定長度的tensor 元組,數據類型由dtypes定義,形狀為shapes。如果shapes沒有定義,那么不同的queue元素可能有不同的形狀,此時就不能使用dqueue_many。如果shapes定義了,則所有的元素必須有相同的形狀
min_after_dequeue決定queue在dequeue以后要保持的元素個數,如果沒有足夠的元素,就會block住dequeue的相關操作,直到有足夠元素進來。當queue關閉,則這個參數被忽略
- enqueue(self, vals, name=None)
enqueue_op = queue.enqueue(...) 創建enqueue元素到queue中的操作
如果操作執行時queue是滿的,則會block住
vals是一個tensor或一個tensor的list/tuple,或者是一個字典,它相當於enqueue操作時的數據池
enqueue操作是要手動觸發的,也就是不是說像一般的那種計算,會把enqueue作為依賴操作被執行
- queue.dequeue(self, name=None)
從queue中取出一個元素
- Coordinator類
__init__(self, clean_stop_exception_types=None)
coord = tf.Coordinator() 協調線程的執行
- QueueRunner類
__init__(self, queue=None, enqueue_ops=None, close_op=None, cancel_op=None, queue_closed_exception_types=None,queue_runner_def=None, import_scope=None)
說明
qr = tf.train.QueueRunner(...) 為一個queue保持一系列enqueue操作,每個操作以一個線程執行
queue: a Queue
enqueue_ops: 一個enqueue ops列表
close_op: 指定關閉queue的操作
cancel_op:指定關閉以及取消掛起的enqueue ops的操作
- qr.create_threads(self, sess, coord=None, daemon=False, start=False)
為給定的sess創建多個線程以執行enqueue ops
start:如果為False,則需要手動調用 start()來啟動
- start_queue_runners
start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=ops.GraphKeys.QUEUE_RUNNERS)
tf.train.start_queue_runners(...) 啟動圖中所有的queue runners,與add_queue_runner()配合使用
start: `False`只是創建線程,但是沒有啟動
二、實例
1 def example1(): 2 3 """ 4 最簡單的例子,只使用enqueue和dequeue 5 :return: 6 """ 7 example = tf.constant(2, "float32", [2, 2]) 8 # 創建一個queue 9 # tf.RandomShuffleQueue(capacity,: queue的容量 10 # min_after_dequeue, : 保證queue中最少的個數 11 # dtypes, 12 # shapes=None,...) 13 queue = tf.RandomShuffleQueue(10, 0, "float32", shapes=[2, 2]) 14 # 為queue添加enqueue操作 15 enqueue_op = queue.enqueue(example) 16 # 為queue添加dequeue操作 17 inputs = queue.dequeue() 18 with tf.Session() as sess: 19 sess.run(tf.global_variables_initializer()) 20 sess.run(enqueue_op) 21 print(sess.run(inputs))
1 def example2(): 2 """ 3 使用queue runner來管理多個enqueue線程,用coord來關閉線程 4 :return: 5 """ 6 data = tf.constant(2, "float32", [2, 2]) 7 example = [data, data, data, data, data, data, data, data] 8 queue = tf.RandomShuffleQueue(10, 0, "float32", shapes=[2, 2]) 9 enqueue_op = queue.enqueue(example) 10 11 qr = tf.train.QueueRunner(queue, [enqueue_op] * 4) 12 coord = tf.train.Coordinator() 13 14 inputs = queue.dequeue() 15 with tf.Session() as sess: 16 threads = qr.create_threads(sess, coord, start=True) 17 sess.run(tf.global_variables_initializer()) 18 print(sess.run(inputs)) 19 # 用coord來停止所有的enqueu線程 20 coord.request_stop() 21 coord.join(threads)
