一、使用urllib下載cifar-10數據集,並讀取再存為圖片(TensorFlow v1.14.0)
# -*- coding:utf-8 -*-
"""Download the CIFAR-10 binary archive with urllib and unpack it with tarfile."""
__author__ = 'Leo.Z'

import os
import sys
import tarfile
import urllib.request


def download_from_url(url, dir=''):
    """Download ``url`` into ``dir`` unless the file is already present.

    Args:
        url: Address of the file. The text after the last ``/`` is used as
            the local file name.
        dir: Target directory; created when missing. An empty string means
            the current working directory. (The name shadows the builtin
            but is kept so existing keyword callers keep working.)

    Returns:
        The local path ``os.path.join(dir, <file name>)``.
    """
    _file_name = url.split('/')[-1]
    _file_path = os.path.join(dir, _file_name)

    # Report hook for urlretrieve: prints the download progress in place.
    def _progress(count, block_size, total_size):
        received = count * block_size
        if total_size > 0:
            # The last block usually overshoots the real size, so clamp.
            percent = min(100.0, float(received) / float(total_size) * 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (_file_name, percent))
        else:
            # Size unknown (no Content-Length): a percentage cannot be
            # computed, so show the byte count instead of dividing by it.
            sys.stdout.write('\r>> Downloading %s %d bytes' % (_file_name, received))
        sys.stdout.flush()

    # Create the directory when it is missing. An empty ``dir`` is the
    # current directory and must not be passed to os.makedirs.
    if dir and not os.path.exists(dir):
        print("Dir is not exsit,Create it..")
        os.makedirs(dir, exist_ok=True)

    if not os.path.exists(_file_path):
        print("Start downloading..")
        # Download to a temporary name and rename on success, so that an
        # interrupted transfer is not mistaken for a finished file later.
        _partial_path = _file_path + '.part'
        urllib.request.urlretrieve(url, _partial_path, _progress)
        sys.stdout.write('\n')
        os.replace(_partial_path, _file_path)
    else:
        print("File already exists..")

    return _file_path


def extract(filepath, dest_dir=None):
    """Unpack the gzip-compressed tar archive ``filepath`` into ``dest_dir``.

    Does nothing when ``filepath`` does not exist. Extraction into an
    already existing directory is allowed; files with the same name are
    overwritten.

    Args:
        filepath: Path of the ``.tar.gz`` archive.
        dest_dir: Directory to unpack into. Defaults to the directory that
            contains the archive.
    """
    if not os.path.exists(filepath):
        return

    if dest_dir is None:
        dest_dir = os.path.dirname(filepath) or '.'

    with tarfile.open(filepath, 'r:gz') as archive:
        # The archive comes from the network: use the 'data' extraction
        # filter where this Python provides it, so that members cannot
        # escape dest_dir.
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(dest_dir, filter='data')
        else:
            archive.extractall(dest_dir)


if __name__ == '__main__':
    FILE_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
    FILE_DIR = 'cifar10_dir/'

    loaded_file_path = download_from_url(FILE_URL, FILE_DIR)

    # The reader script looks for the batch files under
    # FILE_DIR/cifar-10-batches-bin; skip unpacking when that is present.
    extracted_dir = os.path.join(FILE_DIR, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir):
        extract(loaded_file_path, FILE_DIR)
按BATCH_SIZE讀取二進制文件中的圖片數據,並存放為jpg:
# -*- coding:utf-8 -*-
"""Read CIFAR-10 binary batch files with tf.data and save sample images as JPEG.

Written for TensorFlow 1.14.0 (graph mode, run through tf.compat.v1.Session).
"""
__author__ = 'Leo.Z'

import os

import tensorflow as tf
from PIL import Image

BATCH_SIZE = 128


def read_cifar10(filenames):
    """Build tensors that yield batches of labels and images from the files.

    Every record is ``1 + 32 * 32 * 3`` bytes: one label byte followed by
    the image bytes, which are reshaped as [depth, height, width] and then
    transposed to [height, width, depth].

    Args:
        filenames: List of binary batch file paths.

    Returns:
        A ``(label, image)`` tuple of tensors: ``label`` is int32 with shape
        [batch, 1]; ``image`` is uint8 with shape [batch, 32, 32, 3]. The
        batch dimension is BATCH_SIZE, except that the last batch may be
        smaller.
    """
    label_bytes = 1
    height = 32
    width = 32
    depth = 3
    image_bytes = height * width * depth
    record_bytes = label_bytes + image_bytes

    # One dataset element per fixed-length record, as a raw byte string.
    datasets = tf.data.FixedLengthRecordDataset(filenames=filenames,
                                                record_bytes=record_bytes)

    # Dataset transformations return a new dataset, so the result must be
    # reassigned. Shuffle, run two epochs, then group into batches.
    datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE)

    # One-shot iterator (TF 1.14 API); each evaluation yields the next batch.
    iterator = tf.compat.v1.data.make_one_shot_iterator(datasets)
    rec = iterator.get_next()

    # Turn every byte string of the batch into a row of uint8 values.
    rec = tf.io.decode_raw(rec, tf.uint8)

    # -1 keeps the batch dimension dynamic, so a final batch that is
    # shorter than BATCH_SIZE is sliced and reshaped correctly as well.
    label = tf.cast(tf.slice(rec, [0, 0], [-1, label_bytes]), tf.int32)

    # The image bytes start right after the label byte.
    depth_major = tf.reshape(
        tf.slice(rec, [0, label_bytes], [-1, image_bytes]),
        [-1, depth, height, width])

    # Reorder each image to [H, W, C].
    image = tf.transpose(depth_major, [0, 2, 3, 1])

    return (label, image)


def get_data_from_files(data_dir):
    """Return the ``(label, image)`` batch tensors for the files in ``data_dir``.

    Args:
        data_dir: Directory containing data_batch_1.bin .. data_batch_5.bin.

    Returns:
        The tuple produced by ``read_cifar10``.

    Raises:
        ValueError: If one of the five batch files is missing.
    """
    # Five files in total, data_batch_1.bin through data_batch_5.bin.
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]

    for f in filenames:
        if not tf.io.gfile.exists(f):
            raise ValueError('Failed to find file: ' + f)

    return read_cifar10(filenames)


if __name__ == "__main__":
    # Mapping from numeric label to class name.
    label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    name_list = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
    label_map = dict(zip(label_list, name_list))

    with tf.compat.v1.Session() as sess:
        batch_data = get_data_from_files('cifar10_dir/cifar-10-batches-bin')

        # No queue runners and no variable initializer are needed: the
        # tf.data pipeline uses no filename queue and creates no variables.

        # Directory that receives the sample images.
        os.makedirs('cifar10_dir/raw', exist_ok=True)

        # Save 30 images named index-typename.jpg, e.g. 1-frog.jpg.
        for i in range(30):
            # Each run fetches one batch of labels and images.
            batch_data_tuple = sess.run(batch_data)
            # Prints (128, 1)
            print(batch_data_tuple[0].shape)
            # Prints (128, 32, 32, 3)
            print(batch_data_tuple[1].shape)

            # Only the first image of every batch is stored, as a sample.
            Image.fromarray(batch_data_tuple[1][0]).save(
                "cifar10_dir/raw/{index}-{type}.jpg".format(
                    index=i, type=label_map[int(batch_data_tuple[0][0][0])]))
簡要代碼流程圖: