tensorflow ----tutorials之cifar10的卷積神經網絡代碼閱讀之看代碼就是剝洋蔥,邊剝邊流淚之step by step(1)!


      決定寫tensorflow之cifar10的卷積神經網絡代碼閱讀的文章,因為我自己靜不下心來閱讀,所以寫文章不會讓我貪快閱讀從而沒有思考和中斷了可以接上!!!

既然是為了自己,所以就按照自己思路啦,有給他人帶來煩惱,請見諒。恩,思路是從 python cifar10_train.py這個指令開始,到整個訓練,后期可能會給

出數據流向介紹(一般說下次給,就代表不會給了,應該是套路吧)

 

我們開始吧,首先看代碼清單:

 ------------------------cifar10.py

 ------------------------cifar10_eval.py

 ------------------------cifar10_input.py

 ------------------------cifar10_input_test.py

 ------------------------cifar10_train.py

 ------------------------cifar10_multi_gpu_train.py

     官網教程說執行: python cifar10_train.py,這個指令后,你就可以訓練了。來!!!!!!!我們來看下這個cifar10_train.py

執行這一句話,第一執行的代碼是:

if __name__ == '__main__':
tf.app.run()


然后,就跳到main函數啦:
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
train()

main函數第一句話:
 cifar10.maybe_download_and_extract()
調用了cifar10的一個函數,我們來看這個函數cifar10.py:
def maybe_download_and_extract():
"""Download and extract the tarball from Alex's website."""
dest_directory = FLAGS.data_dir
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
if not os.path.exists(extracted_dir_path):
tarfile.open(filepath, 'r:gz').extractall(dest_directory)

這個函數沒什么特別的,沒有代碼閱讀困難吧(對我自己啦,這里沒有炫耀的成分,不要誤解,牛逼的人都知道我真正要表達的是什么意思,下同),
主要功能是: 查看數據,數據在不在,如果在那就ok,不在就網上下載並解壓。

: 看FLAGS.train_dir 文件夾在不在,在則刪掉后創建,不在則創建
估計是運行的log吧,重新運行了當然要把老的log先干掉要,在cifar10_train.py,找到了定義,就是./cifar10_train文件夾
tf.app.flags.DEFINE_string('train_dir', './cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")



回到main函數,發現只剩下train()函數啦:
def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()

# Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on
# GPU and resulting in a slow down.
with tf.device('/cpu:0'):
images, labels = cifar10.distorted_inputs() #艱難啊!!!!!!!!!!!看代碼就是剝洋蔥之數據准備

# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images) #嗯,重頭戲之看代碼就是剝洋蔥之網絡構建

# Calculate loss.
loss = cifar10.loss(logits, labels) #嗯,重頭戲之損失函數構建

# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step) #嗯,重頭戲之訓練流程構建 (為什么構建?,因為在session中才運行的哦,哥哥!!!)

class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""

def begin(self):
self._step = -1
self._start_time = time.time()

def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.

def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time

loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)

format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))

with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)



粘貼下來才發現有點小多,加油增哥,一條條看,肯定可以搞定!!!!!!!!



tf.Graph():官網解釋:
A TensorFlow computation, represented as a dataflow graph
A Graph contains a set of tf.Operation objects, which represent units of computation; and tf.Tensor objects, which represent the units of data that flow between operations

也就是整張圖啦(不懂網上搜索下tensorflow的圖,下同),表示計算等的集合,再多說下

A default Graph is always registered, and accessible by calling tf.get_default_graph. To add an operation to the default graph, simply call one of the functions that defines a new Operation:

c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()


感覺和qt 的graphic framework 類似哦!!!

回來!!!!!!!!! with tf.Graph().as_default(): 定義了一張空白的圖紙,現在我們繼續走下去,准備在圖上畫畫啦!!!!下一條:
global_step = tf.contrib.framework.get_or_create_global_step()
我們在圖上話的第一步是global_step,
還是看不懂,去網上看下.....
顧名思義:

Returns and create (if necessary) the global step tensor.

 
        

Args:

 
        
  • graph: The graph in which to create the global step tensor. If missing, use default graph.
 
        

Returns:

 
        

The global step tensor.

現在問題來了,什么是global step tensor

網上說:

global_step: A scalar int32 or int64 Tensor or a Python number. Global step to use for the decay computation. Must not be negative.

它用於衰減之類的,就是全局計數吧,暫時這么理解。

好!!!!回到train()下一句是:

with tf.device('/cpu:0'):
這是說明在with限制的區域內采用CPU計算,其中,with限制區域有:
images, labels = cifar10.distorted_inputs() #好吧,只有一句話,就是這個函數
--------------------------------------------------------------------------------------------------------------------看代碼就是剝洋蔥之數據准備
走!!!去cifar10.py看代碼(等下記得回來):
def distorted_inputs():
"""Construct distorted input for CIFAR training using the Reader ops.

Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.

Raises:
ValueError: If no data_dir
"""
if not FLAGS.data_dir: #這句是查看數據文件夾有沒有,不講解啦
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
batch_size=FLAGS.batch_size) #是吧,這句才是核心!!!!!!!!!!!!!!!!!!!!!!
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16) #先解釋完上句,核心語句再來收拾這個
labels = tf.cast(labels, tf.float16)
return images, labels

 走!!!去cifar10_input.py看代碼(等下記得回來,這下要玩剝洋蔥的游戲了)cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=FLAGS.batch_size)

 

def distorted_inputs(data_dir, batch_size):
"""Construct distorted input for CIFAR training using the Reader ops.

Args:
data_dir: Path to the CIFAR-10 data directory.
batch_size: Number of images per batch.

Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)] ###為什么是6!!!!,發現在/tmp/cifar10_data/cifar-10-batches-bin/ 下面果然有6個文件,分別是:data_batch_1.bin data_batch_2.bin data_batch_...data_batch_6.bin
for f in filenames: ###現在filenames是一個列表吧,包含六個文件的列表,這還不夠,還要一個一個去check存不存在。
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
##數據都存在,現在開始干活
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames) ####tf.train.string_input_producer這是啥,怎么不懂呀,恩。。。。,去網上看下.....https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer
#Output strings (e.g. filenames) to a queue for an input pipeline
#輸出一系列的字符串,比如文件名,到哪里去?到一個隊列中去(queue),干什么?給input pipeline 用,怎么用?繼續看下面咯


# Read examples from files in the filename queue.
read_input = read_cifar10(filename_queue) #好吧,這里又有個重點要說
reshaped_image = tf.cast(read_input.uint8image, tf.float32) #咦,之前的問題,等當前解決完就可以啦。
嗯,看完read_cifar10函數來解決這啦,這是格式轉換用的cast

height = IMAGE_SIZE #圖像尺寸,在cifar10_input.py中定義IMAGE_SIZE = 24
width = IMAGE_SIZE

# Image processing for training the network. Note the many random
# distortions applied to the image.

# Randomly crop a [height, width] section of the image.
distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) ###這個不難,稍后解決。
恩,現在來解決啦,就是隨機將圖像修剪到我們需要的size[height, width, 3]https://www.tensorflow.org/api_docs/python/tf/random_crop

# Randomly flip the image horizontally. 隨意地水平翻轉圖像 Randomly flip an image horizontally (left to right) https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right
distorted_image = tf.image.random_flip_left_right(distorted_image) #不難

# Because these operations are not commutative, consider randomizing
# the order their operation.
# NOTE: since per_image_standardization zeros the mean and makes
# the stddev unit, this likely has no effect see tensorflow#1458.
distorted_image = tf.image.random_brightness(distorted_image, #Adjust the brightness of images by a random factor.https://www.tensorflow.org/api_docs/python/tf/image/random_brightness
max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image, #Adjust the contrast of an image by a random factor https://www.tensorflow.org/api_docs/python/tf/image/random_contrast
lower=0.2, upper=1.8)

# Subtract off the mean and divide by the variance of the pixels. Subtract off the mean and divide by the variance of the pixels 減去平均值並除以像素的方差
float_image = tf.image.per_image_standardization(distorted_image) #

# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])

# Ensure that the random shuffling has good mixing properties. 確保隨機選取(洗牌)具有很好性能
min_fraction_of_examples_in_queue = 0.4 #隊列中每個樣本最小的分數
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * #在cifar10_input中定義NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
min_fraction_of_examples_in_queue)
print ('Filling queue with %d CIFAR images before starting to train. '
'This will take a few minutes.' % min_queue_examples)

# Generate a batch of images and labels by building up a queue of examples.
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size,
shuffle=True)

 好吧,不好用一兩句話的說的就要在后面解釋了,上面代碼read_input = read_cifar10(filename_queue),這個要解釋,就是讀取.bin數據文件啦。

走!!!去cifar10_input.py看代碼read_cifar10(filename_queue)

 

def read_cifar10(filename_queue):
"""Reads and parses examples from CIFAR10 data files.

Recommendation: if you want N-way read parallelism, call this function
N times. This will give you N independent Readers reading different
files & positions within those files, which will give better mixing of
examples.

Args:
filename_queue: A queue of strings with the filenames to read from.

Returns:#這個要仔細看下,這整個函數中定義了一個類,返回的也是這個類成員
An object representing a single example, with the following fields:
height: number of rows in the result (32)
width: number of columns in the result (32)
depth: number of color channels in the result (3)
key: a scalar string Tensor describing the filename & record number
for this example.
label: an int32 Tensor with the label in the range 0..9.
uint8image: a [height, width, depth] uint8 Tensor with the image data
"""

class CIFAR10Record(object):
pass #當你在編寫一個程序時,執行語句部分思路還沒有完成,這時你可以用pass語句來占位,也可以當做是一個標記,是要過后來完成的代碼。
result = CIFAR10Record()

# Dimensions of the images in the CIFAR-10 dataset.
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
# input format.
label_bytes = 1 # 2 for CIFAR-100 定義了標簽所占的大小
result.height = 32 #定義圖像高
result.width = 32 #定義圖像寬
result.depth = 3 #圖像是rgb所以深度為3
image_bytes = result.height * result.width * result.depth #計算一張圖所占字節,注意哦,這里沒包含標簽的大小哦
# Every record consists of a label followed by the image, with a 這里就介紹啦,標簽緊跟在圖像后面
# fixed number of bytes for each.
record_bytes = label_bytes + image_bytes #這里就計算每一次record 的大小啦

# Read a record, getting filenames from the filename_queue. No #從filename_queue中提取到的filenames中讀取record
# header or footer in the CIFAR-10 format, so we leave header_bytes #cifar10 數據是沒有幀頭和幀尾的,因此頭尾大小為0
# and footer_bytes at their default of 0.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) #FixedLengthRecordReader,簡言就是讀取固定長度的頭,這個函數看官網:https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader

result.key, value = reader.read(filename_queue)# reader 調用read函數,返回A tuple of Tensors (key, value). key: A string scalar Tensor. value: A string scalar Tensor.

# Convert from a string to a vector of uint8 that is record_bytes long.將督導的value從string中轉換成向量,這就涉及到 tf.decode_raw函數:https://www.tensorflow.org/api_docs/python/tf/decode_raw
record_bytes = tf.decode_raw(value, tf.uint8)
‘’‘
decode_raw(
    bytes, ######### bytes: A Tensor of type string. All the elements must have the same length
    out_type, ######## out_type: A tf.DType from: tf.half, tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.int64
    little_endian=None, ### little_endian: An optional bool. Defaults to True. Whether the input bytes are in little-endian order. Ignored for out_type values that are stored in a single byte like uint8
    name=None ####### name: A name for the operation (optional)
)
’‘’
# The first bytes represent the label, which we convert from uint8->int32.第一個字節代表標簽tf.strided_slice函數:https://www.tensorflow.org/api_docs/python/tf/strided_slice
#tf.cast函數:格式轉換//https://www.tensorflow.org/api_docs/python/tf/cast
result.label = tf.cast(
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
# tf.reshape:https://www.tensorflow.org/api_docs/python/tf/reshape
depth_major = tf.reshape(
tf.strided_slice(record_bytes, [label_bytes],
[label_bytes + image_bytes]),
[result.depth, result.height, result.width])
# Convert from [depth, height, width] to [height, width, depth]. tf.transpose:https://www.tensorflow.org/api_docs/python/tf/transpose
result.uint8image = tf.transpose(depth_major, [1, 2, 0])

return result
返回的只有一張image哦,這個封裝在類里面,至此,read_cifar10(filename_queue)結束,回到cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=FLAGS.batch_size)
中去啦,剝掉了一個函數咯,走!!!(然后回到上面繼續看distorted_inputs Îcifar10_input.distorted_inputscifar10_input.distorted_inputs:達代表::代表箭頭)

     哈哈,回到distorted_inputs中從read_cifar10中看完又回來啦,可惜文字看不出動畫哦,文字就直接下來,來來回回地看代碼過程體現不出來。只能用不同字體咯。

現在就差下面的_generate_image_and_label_batch函數啦。   來!!!!!!!!!!!!繼續:

 return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size,
shuffle=True)
這里傳入參數有一個要補充的是:batch_size = tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
 
def _generate_image_and_label_batch(image, label, min_queue_examples,
batch_size, shuffle):
"""Construct a queued batch of images and labels.

Args:
image: 3-D Tensor of [height, width, 3] of type.float32.
label: 1-D Tensor of type.int32
min_queue_examples: int32, minimum number of samples to retain
in the queue that provides of batches of examples.
batch_size: Number of images per batch.
shuffle: boolean indicating whether to use a shuffling queue. shuffling
改組


Returns:
images: Images. 4D tensor of [batch_size, height, width, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
# Create a queue that shuffles the examples, and then
# read 'batch_size' images + labels from the example queue.
num_preprocess_threads = 16
if shuffle:
images, label_batch = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
else:
images, label_batch = tf.train.batch( #tf.train.batch 這個函數很重要哦 https://www.tensorflow.org/api_docs/python/tf/train/batch
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size)

# Display the training images in the visualizer.
tf.summary.image('images', images)

return images, tf.reshape(label_batch, [batch_size])

好吧,是不是很快就吧這個看完了_generate_image_and_label_batch函數,這時候數據就准備好啦!!!!
走!!!干掉了cifar10_input.distorted_inputs,現在回到cifar10.distorted_inputs(cifar10.py)中(又剝掉(洋蔥)一個函數)發現,還剩下:
 if FLAGS.use_fp16:
images = tf.cast(images, tf.float16) #先解釋完上句,核心語句再來收拾這個
labels = tf.cast(labels, tf.float16)
return images, labels


好吧,這我就不說啦!!!!!!!!


走!!!干掉了cifar10.distored_inputs,回到 cifar10_train.pytrain()中啦,接下來是------

--------------------------------------------------------------------看代碼就是剝洋蔥之網絡構建看代碼就是剝洋蔥之網絡構建看代碼就是剝洋蔥之網絡
-------------------------------------------------------------------------------------------------------------------------------------------看代碼就是剝洋蔥之網絡構建

太多,可能有問題,所以分多個文章吧
 
        

 
        
 






 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM