tensorflow的卷積和池化層(二):記實踐之cifar10


tensorflow中的卷積和池化層(一)各種卷積類型Convolution這兩篇博客中,主要講解了卷積神經網絡的核心層,同時也結合當下流行的Caffe和tf框架做了介紹,本篇博客將接着tensorflow中的卷積和池化層(一)的內容,繼續介紹tf框架中卷積神經網絡CNN的使用。

因此,接下來將介紹CNN的入門級教程cifar10\100項目。cifar10\100 數據集是由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集的,這兩個數據集都是從8000萬的數據集中挑選出來的。因此構成它們本身的圖片是很相似的,而區別在於:

  • cifar-10是由60000張表示10類物體的32*32大小的彩色圖片構成,顧名思義,每類剛好6000張,類間數據平衡,而且5000張用於訓練,1000張用於測試和驗證,那么這個數據集就總共有50000張訓練圖片,10000張測試圖片。那包含的10類如下:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck。

  • cifar-100是由60000張表示100類物體的32*32大小的彩色圖片構成,顧名思義,每類剛好600張,類內數據平衡,而且500張用於訓練,100張用於測試和驗證,那么這個數據集就總共有50000張訓練圖片,10000張測試圖片。包含的類別可查看官網。

官網地址如下:The CIFAR-10 and CIFAR-100 dataset  里面給我們提供了3種讀取數據集的方式。

不管學習何種框架,cifar-10\100都是入門級的經典CNN項目,之所以稱之為項目,則是由上述數據集演化的依賴各種不同的工具解決的各類工程問題,此處就是10類或100類的分類問題,因此,這個項目值得大家共同學習。而本篇博客就介紹在tf工具下實現的cifar-10\100項目。

cifar-10的代碼結構如下表所示,總共有5個文件,它們的作用如下:

文件 作用
cifar10_input.py 讀取本地CIFAR-10的二進制文件格式的內容。
cifar10.py 建立CIFAR-10的模型。
cifar10_train.py 在CPU或GPU上訓練CIFAR-10的模型。
cifar10_multi_gpu_train.py 在多GPU上訓練CIFAR-10的模型。
cifar10_eval.py 評估CIFAR-10模型的預測性能。

給出上述5個文件的百度雲地址:cifar-10項目鏈接   密碼: f9ge

由於上述文件中有着詳細的注釋,因此下面對這些文件需要進一步理解的地方具體說明。

  • cifar10_input.py

這個文件是用來讀取官方cifar10\100數據集的,並且是三種數據集方式中的二進制文件格式,也就是一系列data_batch_num.bin(num=1...5)文件和test_batch.bin文件,每一個文件的組織形式都是這樣的:

<1 x label><3072 x pixel> ... <1 x label><3072 x pixel>

那也就是說,每一個batch中每一行都記錄着一張圖片,第一個字節是這個圖片的label,應該是在0-9范圍內的整形變量;而另外3072個字節則表示3個通道的像素值,即3*32*32,應該是按照RGB的順序排列着,即1024R+1024G+1024B。
除此之外,還有另外一個文件,batches.meta.txt,顧名思義,光有label不行,還需要知道每個label代表什么,這個文件里按行存放着每個整形label的表示,並且兩者是一一對應的。
這個文件有4個函數:read_cifar10、_generate_image_and_label_batch、distorted_inputs、inputs。
第一個函數read_cifar10的目的在於讀取一行data_batch_num.bin中的內容,即讀取一張圖片,並且獲取這張圖片的label和按照[height,width,channel]維度組織的像素值。
這里有4個重要的函數:tf.FixedLengthRecordReader、tf.slice、tf.reshape、tf.transpose。其中,
tf.FixedLengthRecordReader是專門用來讀取固定長度字節數的二進制文件閱讀器;tf.slice函數是tf的切片操作,函數原型如下:
tf.slice(inputs,begin,size,name='')
在begin的位置從inputs上抽取size大小的內容,name有默認值,例如:
tf.slice(record_bytes, [0], [label_bytes])
tf.slice(record_bytes, [label_bytes], [image_bytes])
tf.reshape就是將一個tensor的維度重組,函數原型如下:
tf.reshape(tensor,shape,name=None)

將原來的tensor按照shape的樣子重新組織成為新的tensor。例如:

tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])

這樣就可以把slice得到的那一串字符重新組織成[result.depth, result.height, result.width]大小和維度的tensor了。

tf.transpose就是將一個tensor的維度順序進行交換,函數原型如下:
tf.transpose(a, perm=None, name='transpose')

將tensor a按照perm的順序交換變成新的tensor,例如:

 result.uint8image = tf.transpose(depth_major, [1, 2, 0])

這就把[result.depth, result.height, result.width]變成了[result.height, result.width, result.depth]。

第二個函數_generate_image_and_label_batch的目的在於構建一個batch的圖片和相應的label。這里有一個非常重要的函數tf.train.shuffle_batch,例如:

images, label_batch = tf.train.shuffle_batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocess_threads,
      capacity=min_queue_examples + 3 * batch_size,
      min_after_dequeue=min_queue_examples)

這樣理解:這個函數構建了一個capacity大小的隊列,然后呢,在capacity內隨機打亂這些圖片,每次從中取大小為batch_size的圖片數目出列,同時又添加一部分的圖片數目入列,整個過程中在隊列里的圖片數目不能少於min_after_dequeue個,以保證進出之間達到很好的打亂效果。就這樣,這個函數返回一個batch的image和對應的label。

第三個函數distorted_inputs和第四個函數inputs的目的分別是對訓練和測試的數據集進行數據增強和預處理,包括crop、Flip、random_brightness、random_contrast、Whitening等等。最后返回的就是經過這些處理后的數據作為模型真正的輸入。這里面的函數看看代碼即可。

  • cifar10.py

這個文件主要是用來定義cifar10\100模型的,同時定義loss函數,以便訓練。這個文件內共定義了10個函數,分別是_activation_summary、_variable_on_cpu、_variable_with_weight_decay、 distorted_inputs、inputs、inference、loss、_add_loss_summaries、train、maybe_download_and_extract。除此之外就是一些超參數的配置,比如batch_size=128,初始化學習率0.1,學習率的衰減因子0.1,學習速率開始下降的周期數350,移動平均衰減量0.9999等等。

函數的功能如下:

_activation_summary函數為激活函數添加summary,方便在tensorboard中可視化相關節點傳輸的數據。主要是tf.histogram_summary和tf.scalar_summary,這將在后續介紹。

_variable_on_cpu函數的目的是在CPU上創建變量。

_variable_with_weight_decay函數的目的是為了利用高斯分布初始化變量並且需要時添加權重衰減因子weight_decay。

distorted_inputs函數的目的是在cifar10_input.py的基礎上構建訓練數據集,得到經過數據增強和預處理之后模型的輸入數據。

inputs函數的目的是在cifar10_input.py的基礎上構建測試數據集,得到經過數據增強和預處理之后模型測試的輸入數據。

inference函數的目的是構建CNN網絡。

loss函數的目的是定義模型的loss。

 _add_loss_summaries函數的目的是給loss添加summary,以便於可視化。

train函數的目的是訓練cifar10模型,代價函數等。

maybe_download_and_extract函數的目的是從指定的網站下載cifar10數據集並解壓。

 運行上述訓練代碼cifar10_train.py,也不是一帆風順的,錯誤和修改辦法如下:

1. AttributeError: module 'tensorflow.python.ops.image_ops' has no attribute 'random_crop'。

這個錯誤來自於cifar10_input.py文件中的distorted_image = tf.image.random_crop(reshaped_image, [height, width]),將此句修改為:

distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

2. AttributeError: module 'tensorflow.python.ops.image_ops' has no attribute 'per_image_whitening'。

這個錯誤來自於cifar10_input.py文件中的 float_image = tf.image.per_image_whitening(distorted_image),將此句修改為:

 float_image = tf.image.per_image_standardization(distorted_image)

3. AttributeError: module 'tensorflow' has no attribute 'image_summary'。

這個錯誤來自於cifar10_input.py文件中的tf.image_summary('images', images),將此句修改為:

tf.summary.image('images', images)
 tf.summary.scalar('learning_rate', lr)#cifar10.py

注:整個項目類似的地方做修改。

4. AttributeError: module 'tensorflow' has no attribute 'histogram_summary'。

這個錯誤來自於cifar10.py文件中的 tf.histogram_summary(tensor_name + '/activations', x),將此句修改為:

tf.summary.histogram(tensor_name + '/activations', x)
tf.summary.histogram(var.op.name, var)

注:整個項目類似的地方做修改。

5. AttributeError: module 'tensorflow' has no attribute 'scalar_summary'。

這個錯誤來自於cifar10.py文件中的 tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x)),將此句修改為:

tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))

6. AttributeError: module 'tensorflow' has no attribute 'mul'。

這個錯誤之前有說過,改為multiply即可。

7. ValueError: Tried to convert 'tensor' to a tensor and failed. Error: Argument must be a dense tensor: range(0, 128) - got shape [128], but wanted []。

這個錯誤定位在cifar10.py的這句代碼上:

 indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])

將此句改為:

 indices = tf.reshape(list(range(FLAGS.batch_size)), [FLAGS.batch_size, 1])

8. ValueError: Shapes (2, 128, 1) and () are incompatible。

這個錯誤是cifar10.py中的 concated = tf.concat(1, [indices, sparse_labels])觸發的,此句修改為:

 concated = tf.concat([indices, sparse_labels], 1)

9. ValueError: Only call `softmax_cross_entropy_with_logits` with named arguments (labels=..., logits=..., ...)。

這個錯誤來自於softmax_cross_entropy_with_logits這個函數,新的tf版本更新了這個函數,函數原型:

tf.nn.softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, dim=-1, name=None)

將其修改為:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=dense_labels, name='cross_entropy_per_example')

10. TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

提示了,應該將 if grad: 修改為 if grad is not None。

11. AttributeError: module 'tensorflow' has no attribute 'merge_all_summaries'。

此處錯誤來自於cifar10_train.py中的 summary_op = tf.merge_all_summaries(),將其修改為:

 summary_op = tf.summary.merge_all()

12. AttributeError: module 'tensorflow.python.training.training' has no attribute 'SummaryWriter'。

這個錯誤來自於cifar10_train.py的

summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

將其修改為:

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

13. WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.

這個警告來自於cifar10_train.py的

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

將其修改為:

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            sess.graph)

到這里為止,才能成功的運行這個例子,上述錯誤都是由於tf高低版本不兼容導致的,代碼本身沒有問題,只是在高版本的tf做了修改,比如本機的版本是1.7.0,某些低版本應該沒有問題。

以下是運行的log

2018-05-05 13:13:40.782676: step 0, loss = 4.68 (0.1 examples/sec; 918.057 sec/batch)
2018-05-05 13:14:37.995829: step 10, loss = 4.66 (19.7 examples/sec; 6.510 sec/batch)
2018-05-05 13:15:38.571923: step 20, loss = 4.64 (19.3 examples/sec; 6.618 sec/batch)
2018-05-05 13:16:37.660061: step 30, loss = 4.62 (20.4 examples/sec; 6.260 sec/batch)
2018-05-05 13:17:35.194066: step 40, loss = 4.60 (22.7 examples/sec; 5.639 sec/batch)
2018-05-05 13:18:36.177244: step 50, loss = 4.58 (22.5 examples/sec; 5.699 sec/batch)
2018-05-05 13:19:37.775057: step 60, loss = 4.57 (20.9 examples/sec; 6.122 sec/batch)
2018-05-05 13:20:38.255898: step 70, loss = 4.55 (21.0 examples/sec; 6.081 sec/batch)
2018-05-05 13:21:39.074639: step 80, loss = 4.53 (18.5 examples/sec; 6.929 sec/batch)
2018-05-05 13:22:42.469230: step 90, loss = 4.51 (21.9 examples/sec; 5.858 sec/batch)
2018-05-05 13:23:43.102476: step 100, loss = 4.50 (20.5 examples/sec; 6.236 sec/batch)
2018-05-05 13:24:53.920811: step 110, loss = 4.48 (19.1 examples/sec; 6.708 sec/batch)
2018-05-05 13:25:55.722164: step 120, loss = 4.46 (21.0 examples/sec; 6.097 sec/batch)
2018-05-05 13:26:58.607399: step 130, loss = 4.44 (20.8 examples/sec; 6.153 sec/batch)
2018-05-05 13:27:56.598621: step 140, loss = 4.42 (19.4 examples/sec; 6.589 sec/batch)
2018-05-05 13:28:57.043367: step 150, loss = 4.41 (20.9 examples/sec; 6.117 sec/batch)
2018-05-05 13:30:00.026865: step 160, loss = 4.39 (19.3 examples/sec; 6.640 sec/batch)
2018-05-05 13:30:57.701242: step 170, loss = 4.38 (19.2 examples/sec; 6.677 sec/batch)
2018-05-05 13:31:54.940464: step 180, loss = 4.36 (24.6 examples/sec; 5.210 sec/batch)
2018-05-05 13:32:54.969103: step 190, loss = 4.34 (20.9 examples/sec; 6.119 sec/batch)
2018-05-05 13:33:57.856344: step 200, loss = 4.32 (20.2 examples/sec; 6.340 sec/batch)
2018-05-05 13:35:03.489890: step 210, loss = 4.31 (21.5 examples/sec; 5.966 sec/batch)

上面括號里面的兩個數字表示的是每秒跑了多少張圖片和多少秒跑了一個batch,兩者相乘約等於一個batch的圖片數目128。可以從代碼中看出來,如下代碼所示:

if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

由於電腦配置低,跑起來太慢了,20000次,很耗時,所以最終的結果就不貼了。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM