TensorFlow讀取CIFAR-10數據集
覺得有用的話,歡迎一起討論相互學習~
參考文獻
TensorFlow官方文檔
tf.transpose函數解析
tf.slice函數解析
CIFAR-10/CIFAR-100數據集介紹
tf.train.shuffle_batch函數解析
Python urllib urlretrieve函數解析
import os
import tarfile
import tensorflow as tf
from six.moves import urllib
from tensorflow.python.framework import ops
# Start from an empty default graph (helps when the script is re-run in the
# same interpreter, so ops are not added on top of a previous run's graph).
ops.reset_default_graph()

# Switch the working directory to the folder that contains this script (not
# its parent), so the relative data directory used below is created next to
# the script no matter where the interpreter was launched from.
abspath = os.path.abspath(__file__)     # absolute path of this script file
dname = os.path.split(abspath)[0]       # directory holding the script
os.chdir(dname)
# One TF1 session shared by the rest of the script.
sess = tf.Session()

# ---- Training hyper-parameters ----
batch_size = 128        # examples per mini-batch
data_dir = 'temp'       # directory for the downloaded/extracted dataset
output_every = 50       # report the training loss every N generations
generations = 20000     # number of training iterations
eval_every = 500        # report the test loss every N generations

# ---- Image geometry ----
image_height, image_width = 32, 32   # raw CIFAR-10 image size
crop_height, crop_width = 24, 24     # size after crop/pad
num_channels = 3                     # colour channels per image
num_targets = 10                     # number of class labels
extract_folder = 'cifar-10-batches-bin'

# ---- Exponential learning-rate decay ----
learning_rate = 0.1       # initial learning rate
lr_decay = 0.1            # decay rate
num_gens_to_wait = 250.   # generations between learning-rate updates

# ---- Derived record sizes ----
# A record is one label byte (0-9) followed by the flattened image bytes.
image_vec_length = num_channels * image_height * image_width
record_length = image_vec_length + 1
# ---- Fetch and unpack the CIFAR-10 binary archive (only when needed) ----
# data_dir is defined once with the other hyper-parameters; create it on
# first run.
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Source URL and local path of the archive (temp/cifar-10-binary.tar.gz).
cifar10_url = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
data_file = os.path.join(data_dir, 'cifar-10-binary.tar.gz')


def progress(block_num, block_size, total_size):
    """Print download progress; used as the ``reporthook`` of urlretrieve.

    urlretrieve calls the hook once when the connection is established and
    again after each block it reads.

    Args:
        block_num: number of blocks transferred so far.
        block_size: size of one block in bytes.
        total_size: total size of the download in bytes; urlretrieve passes
            a value <= 0 when the server does not report it.
    """
    downloaded = block_num * block_size
    if total_size > 0:
        # The final block is normally only partly filled, so clamp at 100%.
        percent = min(100.0, 100.0 * downloaded / total_size)
        print('\r Downloading {} - {:.2f}%'.format(cifar10_url, percent), end="")
    else:
        # Unknown size: avoid dividing by zero or printing a negative percentage.
        print('\r Downloading {} - {} bytes'.format(cifar10_url, downloaded), end="")


if not os.path.isfile(data_file):
    try:
        urllib.request.urlretrieve(cifar10_url, data_file, progress)
    except BaseException:
        # Never leave a truncated archive behind: the isfile() check above
        # would otherwise skip the download on every later run.
        if os.path.exists(data_file):
            os.remove(data_file)
        raise
    print()  # terminate the '\r' progress line

# Extract whenever the extracted folder is missing, not only right after a
# fresh download, so a kept archive whose extracted folder was removed is
# unpacked again. The context manager closes the archive file handle.
if not os.path.isdir(os.path.join(data_dir, extract_folder)):
    with tarfile.open(data_file, 'r:gz') as archive:
        archive.extractall(data_dir)
# Define CIFAR reader
def read_cifar_files(filename_queue, distort_images=True):
    """Read and preprocess one CIFAR-10 example from ``filename_queue``.

    Each record in the CIFAR-10 binary files is ``record_length`` bytes: one
    label byte (0-9) followed by ``image_vec_length`` pixel bytes. The pixels
    are stored as three channel planes (1024 red bytes, then 1024 green, then
    1024 blue), each plane in row-major order. Every file holds 10000 such
    records back to back with no separators (10000 * 3073 = 30,730,000 bytes).

    Args:
        filename_queue: queue of file names to read records from (e.g. the
            one built by ``tf.train.string_input_producer``).
        distort_images: if True, apply random horizontal flip, brightness and
            contrast augmentation.

    Returns:
        ``(final_image, image_label)``: a float32 image tensor of shape
        ``[crop_height, crop_width, num_channels]`` after per-image
        standardization, and an int32 label tensor of shape ``[1]``.
    """
    reader = tf.FixedLengthRecordReader(record_bytes=record_length)
    # read() yields a (key, value) pair; only the raw record string is used.
    _, record_string = reader.read(filename_queue)
    # The record arrives as a string; reinterpret its bytes as uint8 values.
    record_bytes = tf.decode_raw(record_string, tf.uint8)

    # Byte 0 holds the class label.
    image_label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)

    # Bytes 1..image_vec_length hold the pixels in [channel, height, width] order.
    image_extracted = tf.reshape(tf.slice(record_bytes, [1], [image_vec_length]),
                                 [num_channels, image_height, image_width])
    # Convert to [height, width, channel], the layout the tf.image ops work on.
    image_uint8image = tf.transpose(image_extracted, [1, 2, 0])
    reshaped_image = tf.cast(image_uint8image, tf.float32)

    # Crop or pad to the target size. The signature is
    # (image, target_height, target_width): height must come first.
    final_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, crop_height, crop_width)

    if distort_images:
        # Random horizontal flip plus brightness/contrast jitter.
        final_image = tf.image.random_flip_left_right(final_image)
        final_image = tf.image.random_brightness(final_image, max_delta=63)
        final_image = tf.image.random_contrast(final_image, lower=0.2, upper=1.8)

    # Per-image standardization: (x - mean) / adjusted_stddev, where
    # adjusted_stddev = max(stddev, 1.0 / sqrt(num_elements)) so uniform
    # images do not cause a division by zero.
    final_image = tf.image.per_image_standardization(final_image)
    return (final_image, image_label)
# Create a CIFAR image pipeline from reader
def input_pipeline(batch_size, train_logical=False):
    """Build a shuffled, batched CIFAR-10 input pipeline.

    Args:
        batch_size: number of examples per batch.
        train_logical: True reads the training files ``data_batch_1.bin`` ..
            ``data_batch_5.bin``; False reads ``test_batch.bin``. Random
            augmentation is applied only to the training split.

    Returns:
        ``(example_batch, label_batch)`` tensors produced by
        ``tf.train.shuffle_batch``.
    """
    # Extracted batch files live in <data_dir>/<extract_folder>,
    # i.e. temp/cifar-10-batches-bin.
    batch_dir = os.path.join(data_dir, extract_folder)
    if train_logical:
        files = [os.path.join(batch_dir, 'data_batch_{}.bin'.format(i)) for i in range(1, 6)]
    else:
        files = [os.path.join(batch_dir, 'test_batch.bin')]
    filename_queue = tf.train.string_input_producer(files)
    # Only the training split is augmented; evaluation images must not be
    # randomly flipped or brightness/contrast-jittered.
    image, label = read_cifar_files(filename_queue, distort_images=train_logical)

    # Debug shape prints. NOTE(review): no queue runners have been started at
    # this point; these sess.run calls presumably return only because the
    # shapes are statically known — confirm before adding non-static ops.
    print(train_logical, 'after read_cifar_files ops image', sess.run(tf.shape(image)))
    print(train_logical, 'after read_cifar_files ops label', sess.run(tf.shape(label)))

    # min_after_dequeue is the size of the buffer batches are sampled from:
    # bigger means better shuffling but slower start-up and more memory.
    # capacity must exceed it; the surplus bounds how much is prefetched.
    # Recommended: min_after_dequeue + (num_threads + safety margin) * batch_size.
    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        min_after_dequeue=min_after_dequeue)
    print(train_logical, 'after shuffle_batch ops image', sess.run(tf.shape(image)))
    print(train_logical, 'after shuffle_batch ops example_batch', sess.run(tf.shape(example_batch)))
    print(train_logical, 'after shuffle_batch ops label', sess.run(tf.shape(label)))
    print(train_logical, 'after shuffle_batch ops label_batch', sess.run(tf.shape(label_batch)))
    return (example_batch, label_batch)
# ---- Build the training and test input pipelines ----
print('Getting/Transforming Data.')
# Training split: data_batch_1.bin .. data_batch_5.bin.
images, targets = input_pipeline(batch_size=batch_size, train_logical=True)
# Test split: test_batch.bin.
test_images, test_targets = input_pipeline(batch_size=batch_size, train_logical=False)
sess.close()

# Recorded console output of the shape debug prints (train, then test):
# True after read_cifar_files ops image [24 24 3]
# True after read_cifar_files ops label [1]
# True after shuffle_batch ops image [24 24 3]
# True after shuffle_batch ops example_batch [128 24 24 3]
# True after shuffle_batch ops label [1]
# True after shuffle_batch ops label_batch [128 1]
# False after read_cifar_files ops image [24 24 3]
# False after read_cifar_files ops label [1]
# False after shuffle_batch ops image [24 24 3]
# False after shuffle_batch ops example_batch [128 24 24 3]
# False after shuffle_batch ops label [1]
# False after shuffle_batch ops label_batch [128 1]