將VOC2012轉換為tfrecord


PASCAL-VOC2012簡介

PASCAL-VOC2012數據集介紹官網:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html ,數據集下載地址:benchmark_RELEASE:下載地址 voc2012:下載地址

VOC2012數據集分為20類,包括背景為21類,分別如下: 

  • Person: person 
  • Animal: bird, cat, cow, dog, horse, sheep 
  • Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train 
  • Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor

 

 

再看一下VOC2012數據集里有哪些文件夾:

 

 

在目標檢測中,主要用到了 Annotations,ImageSets,JPEGImages,其中 ImageSets/Main/ 保存了具體數據集的索引,Annotations 保存了標簽數據, JPEGImages 保存了圖片內容。

ImageSets/Main/ 文件夾以 , {class}_trainval.txt {class}_val.txt 的格式命名。 train.txt val.txt 例外,包括 Action,Layout,Main,Segmentation 四個文件夾:

  • Action:存放的是人的動作(例如running、jumping等等,這也是VOC challenge的一部分)
  • Layout:存放的是具有人體部位的數據(人的head、hand、feet等等,這也是VOC challenge的一部分
  • Main:存放的是圖像物體識別的數據,總共分為20類。
  • Segmentation:存放的是可用於分割的數據。

在圖像分割中,主要使用了SegmentationClass,SegmentationObject,JPEGImages有關的信息,VOC2012中的圖片並不是都用於分割,用於分割比賽的圖片實例如下,包含原圖以及圖像分類分割和圖像物體分割兩種png圖。圖像分類分割是在20種物體中,ground-turth圖片上每個物體的輪廓填充都有一個特定的顏色,一共20種顏色,比如摩托車用紅色表示,人用綠色表示。而圖像物體分割則僅僅在一副圖中生成不同物體的輪廓顏色即可,顏色自己隨便填充。

 

 

 

 ImageSets/Main/ 文件夾以 , {class}_trainval.txt {class}_val.txt 的格式命名。 train.txt val.txt 例外

aeroplane_train.txt
aeroplane_trainval.txt
aeroplane_val.txt
bicycle_train.txt
bicycle_trainval.txt
bicycle_val.txt
bird_train.txt
bird_trainval.txt
bird_val.txt
boat_train.txt
boat_trainval.txt
boat_val.txt
bottle_train.txt
bottle_trainval.txt
bottle_val.txt
bus_train.txt
bus_trainval.txt
bus_val.txt
car_train.txt
car_trainval.txt
car_val.txt
cat_train.txt
cat_trainval.txt
cat_val.txt
chair_train.txt
chair_trainval.txt
chair_val.txt
cow_train.txt
cow_trainval.txt
cow_val.txt
diningtable_train.txt
diningtable_trainval.txt
diningtable_val.txt
dog_train.txt
dog_trainval.txt
dog_val.txt
horse_train.txt
horse_trainval.txt
horse_val.txt
motorbike_train.txt
motorbike_trainval.txt
motorbike_val.txt
person_train.txt
person_trainval.txt
person_val.txt
pottedplant_train.txt
pottedplant_trainval.txt
pottedplant_val.txt
sheep_train.txt
sheep_trainval.txt
sheep_val.txt
sofa_train.txt
sofa_trainval.txt
sofa_val.txt
train.txt
train_train.txt
train_trainval.txt
train_val.txt
trainval.txt
tvmonitor_train.txt
tvmonitor_trainval.txt
tvmonitor_val.txt
val.txt

  • {class}_train.txt 保存類別為 class 的訓練集的所有索引,每一個 class 的 train 數據都有 5717 個。
  • {class}_val.txt 保存類別為 class 的驗證集的所有索引,每一個 class 的val數據都有 5823 個
  • {class}_trainval.txt 保存類別為 class 的訓練驗證集的所有索引,每一個 class 的val數據都有11540 個

每個文件包含內容為:

2011_003194 -1
2011_003216 -1
2011_003223 -1
2011_003230 1
2011_003236 1
2011_003238 1
2011_003246 1
2011_003247 0
2011_003253 -1
2011_003255 1
2011_003259 1
2011_003274 -1
2011_003276 -1

注:1代表正樣本,-1代表負樣本。

VOC2012/ImageSets/Main/train.txt 保存了所有訓練集的文件名,從 VOC2012/JPEGImages/ 找到文件名對應的圖片文件。VOC2012/Annotations/ 找到文件名對應的標簽文件

VOC2012/ImageSets/Main/val.txt 保存了所有驗證集的文件名,從 VOC2012/JPEGImages/ 找到文件名對應的圖片文件。VOC2012/Annotations/ 找到文件名對應的標簽文件

讀取 JPEGImages 和 Annotation 文件轉換為 tf 的 Example 對象,寫入 {train|test}{index}_of{num_shard} 文件。每個文件寫的 Example 的數量為 total_size/num_shard。(不同數據集可以適當調節 num_shard 來控制每個輸出文件的大小)

Annotations

文件夾中文件以 {id}.xml (id 保存在 VOC2012/ImageSets/Main/文件夾 ) 格式命名的 xml 文件,保存如下關鍵信息

  • 物體 label : name ,如下例子為 person
  • 圖片尺寸: depth, height, width
  • 物體 bbox : bndbox 下 xmax, xmin, ymax, ymin
<annotation>
	<folder>VOC2012</folder>
	<filename>2007_000032.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
	</source>
	<size>
		<width>500</width>
		<height>281</height>
		<depth>3</depth>
	</size>
	<segmented>1</segmented>
	<object>
		<name>aeroplane</name>
		<pose>Frontal</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>104</xmin>
			<ymin>78</ymin>
			<xmax>375</xmax>
			<ymax>183</ymax>
		</bndbox>
	</object>
	<object>
		<name>aeroplane</name>
		<pose>Left</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>133</xmin>
			<ymin>88</ymin>
			<xmax>197</xmax>
			<ymax>123</ymax>
		</bndbox>
	</object>
	<object>
		<name>person</name>
		<pose>Rear</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>195</xmin>
			<ymin>180</ymin>
			<xmax>213</xmax>
			<ymax>229</ymax>
		</bndbox>
	</object>
	<object>
		<name>person</name>
		<pose>Rear</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>26</xmin>
			<ymin>189</ymin>
			<xmax>44</xmax>
			<ymax>238</ymax>
		</bndbox>
	</object>
</annotation>

tfrecord格式簡介

tfrecord是Tensorflow官方推薦的一種較為高效的數據讀取方式。使用Tensorflow訓練神經網絡時,讀取的數據方式有很多種。如果數據集比較小,而且內存足夠大,可以選擇直接將所有數據讀進內存,然后每次取一個batch的數據出來。如果數據較多,可以每次直接從硬盤中進行讀取,不過這種方式的讀取效率就比較低了。
tfrecord其實是一種數據存儲形式。使用tfrecord時,實際上是先讀取原生數據,然后轉換成tfrecord格式,再存儲在硬盤上。而使用時,再把數據從相應的tfrecord文件中解碼讀取出來。

Tensorflow有和tfrecord配套的一些函數,可以加快數據的處理。實際讀取tfrecord數據時,先以相應的tfrecord文件為參數,創建一個輸入隊列,這個隊列有一定的容量,用戶可以設置不同的值,在一部分數據出隊列時,tfrecord中的其他數據就可以通過預取進入隊列,並且這個過程和網絡的計算是獨立進行的。也就是說,網絡每一個iteration的訓練不必等待數據隊列准備好再開始,隊列中的數據始終是充足的,而往隊列中填充數據時,也可以使用多線程加速。

tfecord文件中的數據是通過tf.train.Example Protocol Buffer的格式存儲的,下面是tf.train.Example的定義。

message Example {
  Features features = 1;
};

message Features{
  map<string,Feature> featrue = 1;
};

message Feature{
  oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

tf.train.Example中包含了屬性名稱到取值的字典,其中屬性名稱為字符串,屬性的取值可以為字符串(BytesList)、實數列表(FloatList)或者整數列表(Int64List)。

將數據保存為tfrecord格式

首先,創建以tfrecord為后綴的文件名

tfrecords_filename = './tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename) # 創建.tfrecord文件,准備寫入

然后創建一個循環一次寫入數據

    for i in range(100):
        img_raw = np.random.random_integers(0,255,size=(7,30)) # 創建7*30,取值在0-255之間隨機數組
        img_raw = img_raw.tostring()
        example = tf.train.Example(features=tf.train.Features(
                feature={
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))
        writer.write(example.SerializeToString()) 
    
    writer.close()

example = tf.train.Example()這句將數據賦給了變量example(可以看到里面是通過字典結構實現的賦值),然后用writer.write(example.SerializeToString()) 這句實現寫入。

值得注意的是賦值給example的數據格式。從前面tf.train.Example的定義可知,tfrecord支持整型、浮點數和二進制三種格式,分別是

tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))

例如圖片等數組形式(array)的數據,可以保存為numpy array的格式,轉換為string,然后保存到二進制格式的feature中。對於單個的數值(scalar),可以直接賦值。這里value=[×]的[]非常重要,也就是說輸入的必須是列表(list)。當然,對於輸入數據是向量形式的,可以根據數據類型(float還是int)分別保存。並且在保存的時候還可以指定數據的維數。

讀取tfrecord數據

tf.parse_single_example解碼,tf.TFRecordReader讀取,一般,為了高效的讀取數據,tf中使用隊列讀取數據

def read_and_decode(filename):
    # 生成一個文件名的隊列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()  # 定義一個reader
    _, serialized_example = reader.read(filename_queue)   # 讀取文件名和example

    # 還原feature, 和制作tfrecords時一樣
    feature = { 'label': tf.FixedLenFeature([], tf.int64),  # 對於單個元素的變量,我們使用FixlenFeature來讀取,需要指明變量存儲的數據類型;對於list類型的變量,我們使用VarLenFeature來讀取,同樣需要指明讀取變量的類型
                'img_raw' : tf.FixedLenFeature([], tf.string), }
    # 使用tf.parse_single_example來解析example
    features = tf.parse_single_example(serialized_example, features=feature)

    # 對於圖像,使用tf.decode_raw解析對應的features,指定類型,然后reshape等
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return img, label

img, label = read_and_decode('train.tfrecords')
# 在訓練時使用shuffle_batch隨機打亂順序,並生成batch
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=30, 
                                                capacity=2000,  # 隊列的最大容量
                                                num_threads=1,  # 進行隊列操作的線程數
                                                min_after_dequeue=1000) # dequeue后最小的隊列大小,used to ensure a level of mixing of elements.

# tf隊列也需要初始化在sess中才能執行                      
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()  # 創建一個coordinate,用於協調各線程
    threads = tf.train.start_queue_runners(coord=coord)  # 使用QueueRunner對象來提取數據

    try:  # 推薦代碼
        while not coord.should_stop():
            # Run training steps or whatever
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        print 'Done training -- epoch limit reached'
    finally:
        # When done, ask the threads to stop.關閉線程
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)

  

以目標檢測所使用的文件為例,制作tfrecord文件代碼如下:

# coding=utf-8
import os
import sys
import random

import numpy as np
import tensorflow as tf
# process a xml file
import xml.etree.ElementTree as ET

DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = 'JPEGImages/'
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 20000

VOC_LABELS = {
    'none': (0, 'Background'),
    'aeroplane': (1, 'Vehicle'),
    'bicycle': (2, 'Vehicle'),
    'bird': (3, 'Animal'),
    'boat': (4, 'Vehicle'),
    'bottle': (5, 'Indoor'),
    'bus': (6, 'Vehicle'),
    'car': (7, 'Vehicle'),
    'cat': (8, 'Animal'),
    'chair': (9, 'Indoor'),
    'cow': (10, 'Animal'),
    'diningtable': (11, 'Indoor'),
    'dog': (12, 'Animal'),
    'horse': (13, 'Animal'),
    'motorbike': (14, 'Vehicle'),
    'person': (15, 'Person'),
    'pottedplant': (16, 'Indoor'),
    'sheep': (17, 'Animal'),
    'sofa': (18, 'Indoor'),
    'train': (19, 'Vehicle'),
    'tvmonitor': (20, 'Indoor'),
}


#返回一個int64_list
def int64_feature(values):
    """Returns a TF-Feature of int64s.
    Args:
    values: A scalar or list of values.
    Returns:
    a TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

#返回float_list
def float_feature(value):
    """Wrapper for inserting float features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
#返回bytes_list
def bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

#split的三種類型
SPLIT_MAP = ['train', 'val', 'trainval']

"""
Process a image and annotation file.
Args:
    filename:       string, path to an image file e.g., '/path/to/example.JPG'.
    coder:          instance of ImageCoder to provide TensorFlow image coding utils.
Returns:
    image_buffer:   string, JPEG encoding of RGB image.
    height:         integer, image height in pixels.
    width:          integer, image width in pixels.
讀取一個樣本圖片及對應信息
directory:圖片所在路徑,name:圖片名稱
"""
def _process_image(directory, name):
    # Read the image file.
    filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
    image_data = tf.gfile.FastGFile(filename, 'rb').read()  #使用gfile讀取圖片
    # Read the XML annotation file.
    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    tree = ET.parse(filename)   #XML文檔表示為樹,ElementTree
    root = tree.getroot()       #樹的根節點
    # Image shape.
    size = root.find('size')
    shape = [int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)]
    # Find annotations.
    # 獲取每個object的信息
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))

        if obj.find('difficult'):
            difficult.append(int(obj.find('difficult').text))
        else:
            difficult.append(0)
        if obj.find('truncated'):
            truncated.append(int(obj.find('truncated').text))
        else:
            truncated.append(0)

        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]
                       ))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated

"""
Build an Example proto for an image example.
Args:
  image_data: string, JPEG encoding of RGB image;
  labels: list of integers, identifier for the ground truth;
  labels_text: list of strings, human-readable labels;
  bboxes: list of bounding boxes; each box is a list of integers;
      specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong
      to the same label as the image label.
  shape: 3 integers, image shapes in pixels.
Returns:
  Example proto
將一個圖片及對應信息按格式轉換成訓練時可讀取的一個樣本
"""
def _convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated):
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    for b in bboxes:
        assert len(b) == 4
        # pylint: disable=expression-not-assigned
        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
        # pylint: enable=expression-not-assigned

    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data)}))
    return example


"""
Loads data from image and annotations files and add them to a TFRecord.
Args:
  dataset_dir: Dataset directory;
  name: Image name to add to the TFRecord;
  tfrecord_writer: The TFRecord writer to use for writing.
"""
def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    image_data, shape, bboxes, labels, labels_text, difficult, truncated = \
        _process_image(dataset_dir, name)
    example = _convert_to_example(image_data,
                                  labels,
                                  labels_text,
                                  bboxes,
                                  shape,
                                  difficult,
                                  truncated)
    tfrecord_writer.write(example.SerializeToString())


"""
以VOC2012為例,下載后的文件名為:VOCtrainval_11-May-2012.tar,解壓后
得到一個文件夾:VOCdevkit
voc_root就是VOCdevkit文件夾所在的路徑
在VOCdevkit文件夾下只有一個文件夾:VOC2012,所以下邊參數year該文件夾的數字部分。
在VOCdevkit/VOC2012/ImageSets/Main下存放了20個類別,每個類別有3個的txt文件:
*.train.txt存放訓練使用的數據
*.val.txt存放測試使用的數據
*.trainval.txt是train和val的合集
所以參數split只能為'train', 'val', 'trainval'之一
"""
def run(voc_root, year, split, output_dir, shuffling=False):
    # 如果output_dir不存在則創建
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)
    # VOCdevkit/VOC2012/ImageSets/Main/train.txt
    # 中存放有所有20個類別的訓練樣本名稱,共5717個
    split_file_path = os.path.join(voc_root, 'VOC%s' % year, 'ImageSets', 'Main', '%s.txt' % split)
    print('>> ', split_file_path)
    with open(split_file_path) as f:
        filenames = f.readlines()
    # shuffling == Ture時,打亂順序
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)
    # Process dataset files.
    i = 0
    fidx = 0
    dataset_dir = os.path.join(voc_root, 'VOC%s' % year)
    while i < len(filenames):
        # Open new TFRecord file.
        tf_filename = '%s/%s_%03d.tfrecord' % (output_dir, split, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                sys.stdout.flush()
                filename = filenames[i].strip()
                _add_to_tfrecord(dataset_dir, filename, tfrecord_writer)
                i += 1
                j += 1
            fidx += 1
    print('\n>> Finished converting the Pascal VOC dataset!')

if __name__ == '__main__':
    # if len(sys.argv) < 2:
    #     raise ValueError('>> error. format: python *.py split_name')
    split = 'train'     #'train|val|trainval'
    if split not in SPLIT_MAP:
        raise ValueError('>> error. split = %s' % split)
    voc_root = 'E:/data/VOCdevkit/'
    run(voc_root, 2012, split,voc_root)

以圖像分割使用文件為例,轉換代碼如下:

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
 
"""Converts PASCAL VOC 2012 data to TFRecord file format with Example protos.
PASCAL VOC 2012 dataset is expected to have the following directory structure:
  + pascal_voc_seg
    - build_data.py
    - build_voc2012_data.py (current working directory).
    + VOCdevkit
      + VOC2012
        + JPEGImages
        + SegmentationClass
        + ImageSets
          + Segmentation
    + tfrecord
Image folder:
  ./VOCdevkit/VOC2012/JPEGImages
Semantic segmentation annotations:
  ./VOCdevkit/VOC2012/SegmentationClass
list folder:
  ./VOCdevkit/VOC2012/ImageSets/Segmentation
This script converts data into sharded data files and save at tfrecord folder.
The Example proto contains the following fields:
  image/encoded: encoded image content.
  image/filename: image filename.
  image/format: image file format.
  image/height: image height.
  image/width: image width.
  image/channels: image channels.
  image/segmentation/class/encoded: encoded semantic segmentation content.
  image/segmentation/class/format: semantic segmentation file format.
"""
import math
import os.path
import sys
import build_data
import tensorflow as tf
 
FLAGS = tf.app.flags.FLAGS
 
tf.app.flags.DEFINE_string('image_folder',
                           './pascal_voc_seg/VOCdevkit/VOC2012/JPEGImages',
                           'Folder containing images.')
 
tf.app.flags.DEFINE_string(
    'semantic_segmentation_folder',
    './pascal_voc_seg/VOCdevkit/VOC2012/SegmentationClassRaw',
    'Folder containing semantic segmentation annotations.')
#train.txt,val.txt,trainval.txt
tf.app.flags.DEFINE_string(
    'list_folder',
    './pascal_voc_seg/VOCdevkit/VOC2012/ImageSets/Segmentation',
    'Folder containing lists for training and validation')
 
#tfrecord輸出路徑
tf.app.flags.DEFINE_string(
    'output_dir',
    './pascal_voc_seg/tfrecord',
    'Path to save converted SSTable of TensorFlow examples.')
 
_NUM_SHARDS = 4
 
 
def _convert_dataset(dataset_split):
    """Converts the specified dataset split to TFRecord format.
    Args:
      dataset_split: The dataset split (e.g., train, test).
    Raises:
      RuntimeError: If loaded image and label have different shape.
    """
    dataset = os.path.basename(dataset_split)[:-4]
    sys.stdout.write('Processing ' + dataset)
    filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
    num_images = len(filenames)
    num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
 
    image_reader = build_data.ImageReader('jpg', channels=3)
    label_reader = build_data.ImageReader('png', channels=1)
 
    for shard_id in range(_NUM_SHARDS):
        output_filename = os.path.join(
            FLAGS.output_dir,
            '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_images)
            for i in range(start_idx, end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, len(filenames), shard_id))
                sys.stdout.flush()
                # Read the image.
                image_filename = os.path.join(
                    FLAGS.image_folder, filenames[i] + '.jpg' )#+ FLAGS.image_format)
                image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
                height, width = image_reader.read_image_dims(image_data)
                # Read the semantic segmentation annotation.
                seg_filename = os.path.join(
                    FLAGS.semantic_segmentation_folder,
                    filenames[i] + '.' + FLAGS.label_format)
                seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
                seg_height, seg_width = label_reader.read_image_dims(seg_data)
                if height != seg_height or width != seg_width:
                    raise RuntimeError('Shape mismatched between image and label.')
                # Convert to tf example.
                example = build_data.image_seg_to_tfexample(
                    image_data, filenames[i], height, width, seg_data)
                tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()
 
 
def main(unused_argv):
    dataset_splits = tf.gfile.Glob(os.path.join(FLAGS.list_folder, '*.txt'))
    for dataset_split in dataset_splits:
        _convert_dataset(dataset_split)
 
 
if __name__ == '__main__':
    tf.app.run()

參考鏈接一

參考鏈接二

參考鏈接三


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM