TensorFlow: Reading TFRecord Data
TensorFlow's pipeline for processing TFRecord-format files and training a model is structured as follows:
1. Obtain the file list and create a file queue: http://blog.csdn.net/lovelyaiq/article/details/78711944 (TFRecord format: saving and reading)
2. Image preprocessing: http://blog.csdn.net/lovelyaiq/article/details/78716325
3. Assemble batches: http://blog.csdn.net/lovelyaiq/article/details/78727189
4. Design the loss function and gradient-descent algorithm: http://blog.csdn.net/lovelyaiq/article/details/78616736
- First, understand the TFRecord format. TensorFlow provides TFRecord as a unified format for managing and storing data. Its tf.train.Example protocol buffer is defined as:
```protobuf
// tf.train.Example
message Example {
  Features features = 1;
}

message Features {
  map<string, Feature> feature = 1;
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}
```
As the definition shows, tf.train.Example stores data in dictionary form: the key is a string, and the value takes one of three types: bytes, float, or int64. The following examples show how to save and read files with TFRecord. The functions used for writing and reading are tf.python_io.TFRecordWriter and tf.TFRecordReader, respectively.
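Before the full conversion script, here is a minimal, self-contained sketch of both sides, assuming TensorFlow 1.x; the file name demo.tfrecord and the feature keys label/raw are illustrative, not from the original post:

```python
import tensorflow as tf

# --- Writing: serialize one Example per record (file name and
# feature keys are illustrative). ---
with tf.python_io.TFRecordWriter('demo.tfrecord') as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
        'raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'\x00\x01'])),
    }))
    writer.write(example.SerializeToString())

# --- Reading: queue-based pipeline with tf.TFRecordReader. ---
filename_queue = tf.train.string_input_producer(['demo.tfrecord'], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue)
features = tf.parse_single_example(serialized, features={
    'label': tf.FixedLenFeature([], tf.int64),
    'raw': tf.FixedLenFeature([], tf.string),
})

with tf.Session() as sess:
    # num_epochs creates a local variable, so initialize both collections.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run([features['label'], features['raw']]))
    coord.request_stop()
    coord.join(threads)
```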
Next, convert the raw data to TFRecord (adapted from the tensorflow/models DeepLab code):
```python
import collections
import math
import os
import sys

import six
import tensorflow as tf

# _get_files, build_data, FLAGS, _NUM_SHARDS, _IMAGE_FORMAT_MAP and
# _IMAGE_FILENAME_RE come from the surrounding deeplab dataset scripts.


def _int64_list_feature(values):
  """Returns a TF-Feature of int64_list.

  Args:
    values: A scalar or list of values.

  Returns:
    A TF-Feature.
  """
  if not isinstance(values, collections.Iterable):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_list_feature(values):
  """Returns a TF-Feature of bytes.

  Args:
    values: A string.

  Returns:
    A TF-Feature.
  """
  def norm2bytes(value):
    return value.encode() if isinstance(value, str) and six.PY3 else value

  return tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))


def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
  """Converts one image/segmentation pair to tf example.

  Args:
    image_data: string of image data.
    filename: image filename.
    height: image height.
    width: image width.
    seg_data: string of semantic segmentation data.

  Returns:
    tf example of one image/segmentation pair.
  """
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': _bytes_list_feature(image_data),
      'image/filename': _bytes_list_feature(filename),
      'image/format': _bytes_list_feature(
          _IMAGE_FORMAT_MAP[FLAGS.image_format]),
      'image/height': _int64_list_feature(height),
      'image/width': _int64_list_feature(width),
      'image/channels': _int64_list_feature(3),
      'image/segmentation/class/encoded': (
          _bytes_list_feature(seg_data)),
      'image/segmentation/class/format': _bytes_list_feature(
          FLAGS.label_format),
  }))


def _convert_dataset(dataset_split):
  """Converts the specified dataset split to TFRecord format.

  Args:
    dataset_split: The dataset split (e.g., train, val).

  Raises:
    RuntimeError: If loaded image and label have different shape, or if the
      image file with specified postfix could not be found.
  """
  image_files = _get_files('image', dataset_split)  # Get the file lists.
  label_files = _get_files('label', dataset_split)

  num_images = len(image_files)
  num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))

  image_reader = build_data.ImageReader('png', channels=3)
  label_reader = build_data.ImageReader('png', channels=1)

  for shard_id in range(_NUM_SHARDS):  # Write _NUM_SHARDS tfrecord files.
    shard_filename = '%s-%05d-of-%05d.tfrecord' % (
        dataset_split, shard_id, _NUM_SHARDS)
    output_filename = os.path.join(FLAGS.output_dir, shard_filename)
    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
      start_idx = shard_id * num_per_shard
      end_idx = min((shard_id + 1) * num_per_shard, num_images)
      for i in range(start_idx, end_idx):  # Read and convert file by file.
        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
            i + 1, num_images, shard_id))
        sys.stdout.flush()
        # Read the image.
        image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
        height, width = image_reader.read_image_dims(image_data)
        # Read the semantic segmentation annotation.
        seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
        seg_height, seg_width = label_reader.read_image_dims(seg_data)
        if height != seg_height or width != seg_width:
          raise RuntimeError('Shape mismatched between image and label.')
        # Convert to tf example.
        re_match = _IMAGE_FILENAME_RE.search(image_files[i])
        if re_match is None:
          raise RuntimeError('Invalid image filename: ' + image_files[i])
        filename = os.path.basename(re_match.group(1))
        example = build_data.image_seg_to_tfexample(
            image_data, filename, height, width, seg_data)
        tfrecord_writer.write(example.SerializeToString())
  sys.stdout.write('\n')
  sys.stdout.flush()


def main(unused_argv):
  # Only support converting 'train' and 'val' sets for now.
  for dataset_split in ['train', 'val']:
    _convert_dataset(dataset_split)
```
Then read the data at training time. There are three approaches: raw queue-based reading, tf.data.TFRecordDataset, and a slim-based implementation. See respectively:
- http://blog.csdn.net/lovelyaiq/article/details/78711944 (TFRecord format: saving and reading)
- TensorFlow basics: TFRecord and tf.data.TFRecordDataset (a sketch of this approach follows the deeplab code below)
- https://github.com/NanqingD/DeepLabV3-Tensorflow/blob/master/libs/datasets/dataset_factory.py
- The deeplab implementation, for reference:
```python
import os

import tensorflow as tf

slim = tf.contrib.slim
dataset = slim.dataset
tfexample_decoder = slim.tfexample_decoder

# _DATASETS_INFORMATION, _FILE_PATTERN and _ITEMS_TO_DESCRIPTIONS are
# module-level constants defined elsewhere in the deeplab code.


def get_dataset(dataset_name, split_name, dataset_dir):
  """Gets an instance of slim Dataset.

  Args:
    dataset_name: Dataset name.
    split_name: A train/val Split name.
    dataset_dir: The directory of the dataset sources.

  Returns:
    An instance of slim Dataset.

  Raises:
    ValueError: if the dataset_name or split_name is not recognized.
  """
  if dataset_name not in _DATASETS_INFORMATION:
    raise ValueError('The specified dataset is not supported yet.')

  splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
  if split_name not in splits_to_sizes:
    raise ValueError('data split name %s not recognized' % split_name)

  # Prepare the variables for different datasets.
  num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
  ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label

  file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Specify how the TF-Examples are decoded.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
      'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
      'image/segmentation/class/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/segmentation/class/format': tf.FixedLenFeature(
          (), tf.string, default_value='png'),
  }
  items_to_handlers = {
      'image': tfexample_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3),
      'image_name': tfexample_decoder.Tensor('image/filename'),
      'height': tfexample_decoder.Tensor('image/height'),
      'width': tfexample_decoder.Tensor('image/width'),
      'labels_class': tfexample_decoder.Image(
          image_key='image/segmentation/class/encoded',
          format_key='image/segmentation/class/format',
          channels=1),
  }

  decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  return dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=splits_to_sizes[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      ignore_label=ignore_label,
      num_classes=num_classes,
      name=dataset_name,
      multi_label=True)
```
- The dataset then flows through the provider and preprocessing:
```python
# ......
data_provider = dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=num_readers,
    num_epochs=None if is_training else 1,
    shuffle=is_training)
image, label, image_name, height, width = _get_data(data_provider,
                                                    dataset_split)
if label is not None:
  if label.shape.ndims == 2:
    label = tf.expand_dims(label, 2)
  elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
    pass
  else:
    raise ValueError('Input label shape must be [height, width], or '
                     '[height, width, 1].')
  label.set_shape([None, None, 1])

original_image, image, label = input_preprocess.preprocess_image_and_label(
    image,
    label,
    crop_height=crop_size[0],
    crop_width=crop_size[1],
    min_resize_value=min_resize_value,
    max_resize_value=max_resize_value,
    resize_factor=resize_factor,
    min_scale_factor=min_scale_factor,
    max_scale_factor=max_scale_factor,
    scale_factor_step_size=scale_factor_step_size,
    ignore_label=dataset.ignore_label,
    is_training=is_training,
    model_variant=model_variant)

sample = {
    common.IMAGE: image,
    common.IMAGE_NAME: image_name,
    common.HEIGHT: height,
    common.WIDTH: width
}
if label is not None:
  sample[common.LABEL] = label

if not is_training:
  # Original image is only used during visualization.
  sample[common.ORIGINAL_IMAGE] = original_image
  num_threads = 1

return tf.train.batch(
    sample,
    batch_size=batch_size,
    num_threads=num_threads,
    capacity=32 * batch_size,
    allow_smaller_final_batch=not is_training,
    dynamic_pad=True)
```
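For comparison, the tf.data.TFRecordDataset route referenced above replaces the queue runners entirely. A minimal sketch, assuming the same deeplab-style feature keys as the records written earlier; the file pattern, resize shape, and batch size are illustrative:

```python
import tensorflow as tf

def _parse_fn(serialized):
    # Feature keys match the deeplab-style records written above.
    features = tf.parse_single_example(serialized, features={
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/encoded': tf.FixedLenFeature(
            (), tf.string, default_value=''),
    })
    image = tf.image.decode_png(features['image/encoded'], channels=3)
    label = tf.image.decode_png(
        features['image/segmentation/class/encoded'], channels=1)
    # Resize to a fixed shape so examples can be batched.
    image = tf.image.resize_images(image, [256, 256])
    label = tf.image.resize_images(
        label, [256, 256], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return image, label

# 'train-*.tfrecord' and the numeric arguments are illustrative.
filenames = tf.gfile.Glob('train-*.tfrecord')
dataset = (tf.data.TFRecordDataset(filenames)
           .map(_parse_fn, num_parallel_calls=4)
           .shuffle(buffer_size=1000)
           .batch(8)
           .repeat())
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
```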
- The main focus here is TensorFlow's TFRecord reading methods, including reading data with slim:
```python
def read_data(is_training, split_name):
    file_pattern = '{}_{}.tfrecord'.format(args.data_name, split_name)
    tfrecord_path = os.path.join(args.data_dir, 'records', file_pattern)
    if is_training:
        # Read the tfrecord via slim.
        dataset = get_dataset(tfrecord_path)
        image, gt_mask = extract_batch(dataset, args.batch_size, is_training)
    else:
        # Read the tfrecord via the raw reader.
        image, gt_mask = read_tfrecord(tfrecord_path)
        image, gt_mask = preprocess.preprocess_image(image, gt_mask,
                                                     is_training)
    return image, gt_mask
```
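The body of read_tfrecord (the raw branch above) is not shown in the source; a hypothetical sketch of what it could look like, assuming PNG-encoded image/segmentation features with the keys used earlier (keys and shapes are assumptions):

```python
import tensorflow as tf

def read_tfrecord(tfrecord_path):
    # Queue of input files; a single file here, more can be listed.
    filename_queue = tf.train.string_input_producer([tfrecord_path])
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(serialized, features={
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/encoded': tf.FixedLenFeature(
            (), tf.string, default_value=''),
    })
    image = tf.image.decode_png(features['image/encoded'], channels=3)
    gt_mask = tf.image.decode_png(
        features['image/segmentation/class/encoded'], channels=1)
    return image, gt_mask
```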
1. Data processing pipeline
The handling of input data follows roughly the same flow in most cases, summarized as follows (a sketch tying the steps together appears after the list):
Convert the data into multiple TFRecord-format files
Create the file list with tf.train.match_filenames_once()
Create the input file queue with tf.train.string_input_producer(), optionally shuffling the file order
Read the data from the files with tf.TFRecordReader()
Parse the data with tf.parse_single_example()
Decode and preprocess the data
Combine the data into batches with tf.train.shuffle_batch()
Feed the batches to training
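A minimal end-to-end sketch of those steps, assuming TFRecord files matching 'data-*.tfrecord' that each hold a fixed-size flattened image and an int64 label (all names and sizes are illustrative):

```python
import tensorflow as tf

# Steps 1-3: file list and shuffled input queue.
files = tf.train.match_filenames_once('data-*.tfrecord')
filename_queue = tf.train.string_input_producer(files, shuffle=True)

# Steps 4-5: read and parse one serialized example.
reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue)
features = tf.parse_single_example(serialized, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

# Step 6: decode and preprocess (a fixed 28x28x1 shape is assumed).
image = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image, [28, 28, 1])
image = tf.cast(image, tf.float32) / 255.0
label = tf.cast(features['label'], tf.int32)

# Step 7: combine into batches.
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=32, capacity=1000 + 3 * 32,
    min_after_dequeue=1000)

# Step 8: use the batch for training inside a session with queue runners.
with tf.Session() as sess:
    # match_filenames_once creates a local variable, hence both initializers.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    imgs, lbls = sess.run([image_batch, label_batch])
    coord.request_stop()
    coord.join(threads)
```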
2. Input data processing framework
The framework covers three main aspects (a sketch of the multi-threaded part follows the list):
TFRecord input data format
Image data processing
Multi-threaded input data processing
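The multi-threaded part rests on tf.train.Coordinator plus queue runners. A minimal sketch of running an enqueue op on several threads and consuming the results (the queue sizes and thread count are illustrative):

```python
import tensorflow as tf

# A FIFO queue fed by several enqueue threads (sizes are illustrative).
queue = tf.FIFOQueue(capacity=100, dtypes=[tf.float32], shapes=[[]])
enqueue_op = queue.enqueue([tf.random_normal([])])
dequeue_op = queue.dequeue()

# Register the enqueue op to run on 4 background threads.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
tf.train.add_queue_runner(qr)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(5):
        print(sess.run(dequeue_op))  # Consumes values produced by the threads.
    coord.request_stop()
    coord.join(threads)
```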