有时,为了调试数据,需要将数据集中的内容打印出来,可以用 iterator 来遍历数据集。
首先定义两个遍历函数:
def print_dataset(self, data_set):
    """Print every batch of ``data_set`` for debugging via a one-shot iterator.

    Each batch is printed together with a running batch count that starts
    at 1.

    Args:
        data_set: Dataset object exposing ``make_one_shot_iterator()``.

    Note:
        A one-shot iterator cannot be built for datasets that capture
        stateful objects such as a ``Variable`` or ``LookupTable``; the
        call to ``make_one_shot_iterator()`` raises ``ValueError`` then.
    """
    batch_tensor = data_set.make_one_shot_iterator().get_next()
    batches_seen = 0
    with tf.train.MonitoredTrainingSession() as session:
        # There is no explicit end-of-data check here: the loop keeps pulling
        # until the session reports it should stop or ``session.run`` raises.
        # NOTE(review): presumably the monitored session ends cleanly on
        # dataset exhaustion -- confirm against the TensorFlow docs.
        while not session.should_stop():
            batch = session.run(batch_tensor)
            batches_seen += 1
            print("Num Batch: ", batches_seen)
            print("Batch value: ", batch)
def print_dataset2(self, data_set):
    """Print every batch of ``data_set`` for debugging via an initializable iterator.

    Unlike a one-shot iterator, an initializable iterator is explicitly
    initialized inside the session, so this variant is the one to use for
    datasets whose pipeline captures stateful objects (e.g. lookup tables).

    Each batch is printed together with a running batch count that starts
    at 1, matching ``print_dataset``.

    Args:
        data_set: Dataset object exposing ``make_initializable_iterator()``.
    """
    iterator = data_set.make_initializable_iterator()
    next_element = iterator.get_next()
    num_batch = 0
    with tf.train.MonitoredTrainingSession() as sess:
        sess.run(iterator.initializer)
        while True:
            # Keep the try body to the single call that signals exhaustion.
            try:
                value = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                # Raised once the dataset has no more elements.
                break
            # Count before printing so the first batch is reported as 1.
            num_batch += 1
            print("Num Batch: ", num_batch)
            print("Batch value: ", value)
第一个函数使用 one-shot iterator,不支持 lookup(查找表)等带状态的操作,否则会报如下错误:
ValueError: Failed to create a one-shot iterator for a dataset. `Dataset.make_one_shot_iterator()` does not support datasets that capture stateful objects, such as a `Variable` or `LookupTable`. In these cases, use `Dataset.make_initializable_iterator()`. (Original error: Cannot capture a stateful node (name:hash_table, type:HashTableV2) by value.)
在这种情况下,使用第二个函数。
# Raw text lines: print them with the one-shot-iterator helper.
dataset = tf.data.TextLineDataset(file_list)
print("=======len(file_list)=======", len(file_list))
print(dataset)
self.print_dataset(dataset)

# The mapped dataset calls self.case_table.lookup on both elements of each
# (src, tgt) pair, so it is printed with the initializable-iterator helper.
# NOTE(review): presumably case_table is a stateful lookup table -- confirm.
looked_up = self.text_set.map(
    lambda src, tgt: (self.case_table.lookup(src), self.case_table.lookup(tgt))
)
self.text_set = looked_up.prefetch(buffer_size)
self.print_dataset2(self.text_set)