利用Python 脚本生成 .h5 文件 代码


利用Python 脚本生成 .h5 文件 

 1 import os, json, argparse  2 from threading import Thread  3 from Queue import Queue  4 
 5 import numpy as np  6 from scipy.misc import imread, imresize  7 import h5py  8 
 9 """
 10 Create an HDF5 file of images for training a feedforward style transfer model.  11 """
 12 
 13 parser = argparse.ArgumentParser()  14 parser.add_argument('--train_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/train2014')  15 parser.add_argument('--val_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/val2014')  16 parser.add_argument('--output_file', default='/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5')  17 parser.add_argument('--height', type=int, default=256)  18 parser.add_argument('--width', type=int, default=256)  19 parser.add_argument('--max_images', type=int, default=-1)  20 parser.add_argument('--num_workers', type=int, default=2)  21 parser.add_argument('--include_val', type=int, default=1)  22 parser.add_argument('--max_resize', default=16, type=int)  23 args = parser.parse_args()  24 
 25 
 26 def add_data(h5_file, image_dir, prefix, args):  27   # Make a list of all images in the source directory
 28   image_list = []  29   image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'}  30   for filename in os.listdir(image_dir):  31     ext = os.path.splitext(filename)[1]  32     if ext in image_extensions:  33  image_list.append(os.path.join(image_dir, filename))  34   num_images = len(image_list)  35 
 36   # Resize all images and copy them into the hdf5 file
 37   # We'll bravely try multithreading
 38   dset_name = os.path.join(prefix, 'images')  39   dset_size = (num_images, 3, args.height, args.width)  40   imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8)  41   
 42   # input_queue stores (idx, filename) tuples,
 43   # output_queue stores (idx, resized_img) tuples
 44   input_queue = Queue()  45   output_queue = Queue()  46   
 47   # Read workers pull images off disk and resize them
 48   def read_worker():  49     while True:  50       idx, filename = input_queue.get()  51       img = imread(filename)  52       try:  53         # First crop the image so its size is a multiple of max_resize
 54         H, W = img.shape[0], img.shape[1]  55         H_crop = H - H % args.max_resize  56         W_crop = W - W % args.max_resize  57         img = img[:H_crop, :W_crop]  58         img = imresize(img, (args.height, args.width))  59       except (ValueError, IndexError) as e:  60         print filename  61         print img.shape, img.dtype  62         print e  63  input_queue.task_done()  64  output_queue.put((idx, img))  65   
 66   # Write workers write resized images to the hdf5 file
 67   def write_worker():  68     num_written = 0  69     while True:  70       idx, img = output_queue.get()  71       if img.ndim == 3:  72         # RGB image, transpose from H x W x C to C x H x W
 73         imgs_dset[idx] = img.transpose(2, 0, 1)  74       elif img.ndim == 2:  75         # Grayscale image; it is H x W so broadcasting to C x H x W will just copy
 76         # grayscale values into all channels.
 77         imgs_dset[idx] = img  78  output_queue.task_done()  79       num_written = num_written + 1
 80       if num_written % 100 == 0:  81         print 'Copied %d / %d images' % (num_written, num_images)  82   
 83   # Start the read workers.
 84   for i in xrange(args.num_workers):  85     t = Thread(target=read_worker)  86     t.daemon = True  87  t.start()  88     
 89   # h5py locks internally, so we can only use a single write worker =(
 90   t = Thread(target=write_worker)  91   t.daemon = True  92  t.start()  93     
 94   for idx, filename in enumerate(image_list):  95     if args.max_images > 0 and idx >= args.max_images: break
 96  input_queue.put((idx, filename))  97     
 98  input_queue.join()  99  output_queue.join() 100   
101   
102   
def _build_dataset():
  # Write the training split, and optionally the validation split, into a
  # single freshly-created HDF5 output file.
  with h5py.File(args.output_file, 'w') as f:
    add_data(f, args.train_dir, 'train2014', args)
    if args.include_val != 0:
      add_data(f, args.val_dir, 'val2014', args)


if __name__ == '__main__':
  _build_dataset()

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM