有時,為了調試數據,需要將數據打印出來,可以用 iterator 來遍歷數據。
首先定義兩個遍歷函數:
def print_dataset(self, data_set):
    """Print every element of ``data_set`` using a one-shot iterator.

    Debugging helper. For each element fetched it prints a running,
    1-based counter followed by the fetched value.

    Args:
        data_set: Dataset object exposing ``make_one_shot_iterator()``.
            One-shot iterators reject datasets that capture stateful
            objects such as a ``Variable`` or ``LookupTable`` (a
            ``ValueError`` is raised); such datasets need an
            initializable iterator instead.
    """
    fetch_op = data_set.make_one_shot_iterator().get_next()
    batches_seen = 0
    with tf.train.MonitoredTrainingSession() as session:
        # NOTE(review): the loop relies on the monitored session to end
        # iteration once the data is exhausted -- confirm the end-of-data
        # error is absorbed by the session on the TF version in use.
        while not session.should_stop():
            batch = session.run(fetch_op)
            batches_seen += 1
            print("Num Batch: ", batches_seen)
            print("Batch value: ", batch)
def print_dataset2(self, data_set):
    """Print every element of ``data_set`` using an initializable iterator.

    Debugging helper for datasets that a one-shot iterator cannot handle
    (for example, ones whose ``map`` function captures a lookup table).
    For each element fetched it prints a running, 1-based counter
    followed by the fetched value, the same numbering ``print_dataset``
    uses.

    Args:
        data_set: Dataset object exposing ``make_initializable_iterator()``.
    """
    iterator = data_set.make_initializable_iterator()
    next_element = iterator.get_next()
    num_batch = 0
    with tf.train.MonitoredTrainingSession() as sess:
        # The iterator must be initialized explicitly before the first fetch.
        sess.run(iterator.initializer)
        while True:
            # Keep the try body to the single call that signals end-of-data.
            try:
                value = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                break
            num_batch += 1
            print("Num Batch: ", num_batch)
            print("Batch value: ", value)
第一個函數使用 one-shot iterator,不支持含有 lookup table 等有狀態操作的數據集,會報如下錯誤:
ValueError: Failed to create a one-shot iterator for a dataset. `Dataset.make_one_shot_iterator()` does not support datasets that capture stateful objects, such as a `Variable` or `LookupTable`. In these cases, use `Dataset.make_initializable_iterator()`. (Original error: Cannot capture a stateful node (name:hash_table, type:HashTableV2) by value.)
在這種情況下,使用第二個函數。
# Plain text-line dataset: printed with the one-shot-iterator helper.
dataset = tf.data.TextLineDataset(file_list)
print("=======len(file_list)=======", len(file_list))
print(dataset)
self.print_dataset(dataset)

# The mapped dataset calls case_table.lookup (presumably a lookup table --
# confirm against where case_table is built), so it is printed with the
# initializable-iterator helper instead.
looked_up = self.text_set.map(
    lambda source, target: (
        self.case_table.lookup(source),
        self.case_table.lookup(target),
    )
)
self.text_set = looked_up.prefetch(buffer_size)
self.print_dataset2(self.text_set)