TensorFlow -- MNIST 官方例程註解


我自己對 MNIST 官方例程進行了部分註解，希望分享出來能幫助入門者更好地理解 TensorFlow 的運行機制。可以把代碼拷貝到 IDE 裡調試，觀察具體的數據流向，以及其中用到的一部分 TensorFlow 庫。
我用的是pip安裝的tensorflow-GPU-1.13,這段源碼原始位置在https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

代碼:

  1 from __future__ import absolute_import
  2 from __future__ import division
  3 from __future__ import print_function
  4 
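  5 # absl (the absl-py package, Google's Abseil Python library) is a third-party package, not part of the Python standard library.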
  6 from absl import app as absl_app
  7 from absl import flags
  8 
  9 import tensorflow as tf  # pylint: disable=g-bad-import-order
 10 
 11 from official.mnist import dataset
 12 from official.utils.flags import core as flags_core
 13 from official.utils.logs import hooks_helper
 14 from official.utils.misc import distribution_utils
 15 from official.utils.misc import model_helpers
 16 
 17 
 18 LEARNING_RATE = 1e-4
 19 
 20 #參數默認data_format = 'channels_first'
 21 def create_model(data_format):
 22   """Model to recognize digits in the MNIST dataset.
 23 
 24   Network structure is equivalent to:
 25   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
 26   and
 27   https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
 28 
 29   But uses the tf.keras API.
 30 
 31   Args:
 32     data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
 33       typically faster on GPUs while 'channels_last' is typically faster on
 34       CPUs. See
 35       https://www.tensorflow.org/performance/performance_guide#data_formats
 36 
 37   Returns:
 38     A tf.keras.Model.
 39   """
 40 
 41   #data_format:一個字符串,可以是channels_last(默認)或channels_first,\
 42   # 表示輸入中維度的順序,channels_last對應於具有形狀(batch, height, width, channels)\
 43   # 的輸入,而channels_first對應於具有形狀(batch, channels, height, width)的輸入.
 44   #這里感覺輸入只有三個維度,默認是單通道圖?
 45   if data_format == 'channels_first':
 46     input_shape = [1, 28, 28]
 47   else:
 48     assert data_format == 'channels_last'
 49     input_shape = [28, 28, 1]
 50 
 51   #將tf.keras.layers.MaxPooling2D傳遞給max_pool
 52   l = tf.keras.layers
 53   max_pool = l.MaxPooling2D(
 54       (2, 2), (2, 2), padding='same', data_format=data_format)
 55   # The model consists of a sequential chain of layers, so tf.keras.Sequential
 56   # (a subclass of tf.keras.Model) makes for a compact description.
 57   return tf.keras.Sequential(
 58       [
 59           #輸入層確保輸入的大小符合網絡需要[28, 28]->[1, 28, 28]
 60           l.Reshape(
 61               target_shape=input_shape,
 62               input_shape=(28 * 28,)),
 63           #卷積
 64           l.Conv2D(
 65               32,#filters:整數, 輸出空間的維數(即卷積中的濾波器數),就是卷積核個數
 66               5,#卷積核大小,這里是5x5
 67               padding='same',
 68               data_format=data_format,
 69               activation=tf.nn.relu),
 70           #最大pooling
 71           max_pool,
 72           #卷積
 73           l.Conv2D(
 74               64,
 75               5,
 76               padding='same',
 77               data_format=data_format,
 78               activation=tf.nn.relu),
 79           # 最大pooling
 80           max_pool,
 81           #在保留第0軸的情況下對輸入的張量進行Flatten(扁平化),拉直?
 82           l.Flatten(),
 83           #fc 1024 -> units: 該層的神經單元結點數。
 84           l.Dense(1024, activation=tf.nn.relu),
 85           l.Dropout(0.4),
 86           #fc輸出
 87           l.Dense(10)
 88       ])
 89 
 90 #添加了很多參數,指定了一部分的值,數據url,模型url,batch_size等等
 91 def define_mnist_flags():
 92   flags_core.define_base()
 93   flags_core.define_performance(num_parallel_calls=False)
 94   flags_core.define_image()
 95   flags.adopt_module_key_flags(flags_core)
 96   #自定義項參數都在這里設置了
 97   flags_core.set_defaults(data_dir='./tmp/mnist_data',
 98                           model_dir='./tmp/mnist_model',
 99                           batch_size=100,
100                           train_epochs=40,
101                           stop_threshold=0.998)
102 
103 
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: The image tensor, or a dict holding it under the 'image' key
      (the serving input receiver built in run_mnist passes such a dict).
    labels: Class-index labels consumed by sparse_softmax_cross_entropy and
      tf.metrics.accuracy; not used in PREDICT mode.
    mode: A tf.estimator.ModeKeys value (PREDICT, TRAIN or EVAL).
    params: Dict forwarded by the Estimator; must contain 'data_format'.

  Returns:
    A tf.estimator.EstimatorSpec for PREDICT, TRAIN or EVAL. Any other mode
    falls through and returns None.
  """
  modes = tf.estimator.ModeKeys
  net = create_model(params['data_format'])
  image = features['image'] if isinstance(features, dict) else features

  # Inference / serving: expose predicted class and class probabilities.
  if mode == modes.PREDICT:
    logits = net(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=modes.PREDICT,
        predictions=predictions,
        export_outputs=export_outputs)

  # Training: Adam on the softmax cross-entropy loss, dropout enabled.
  if mode == modes.TRAIN:
    adam = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    logits = net(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # tf.metrics.accuracy returns (value, update_op); the update op both
    # updates the running accuracy and yields its new value.
    _, running_accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(running_accuracy, name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', running_accuracy)

    train_op = adam.minimize(loss, tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(
        mode=modes.TRAIN, loss=loss, train_op=train_op)

  # Evaluation: loss plus the accuracy metric, dropout disabled.
  if mode == modes.EVAL:
    logits = net(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    metric_ops = {
        'accuracy': tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(logits, axis=1)),
    }
    return tf.estimator.EstimatorSpec(
        mode=modes.EVAL, loss=loss, eval_metric_ops=metric_ops)
161 
162 
def _epochs_per_cycle(train_epochs, epochs_between_evals):
  """Split `train_epochs` into train/eval cycles of `epochs_between_evals`.

  The last cycle is shortened so that exactly `train_epochs` epochs are
  trained, instead of silently dropping the remainder (or training nothing
  when `epochs_between_evals > train_epochs`).

  Args:
    train_epochs: Total number of training epochs; <= 0 means no training.
    epochs_between_evals: Epochs to train between evaluations; must be >= 1.

  Returns:
    A list with the number of epochs to train in each cycle.

  Raises:
    ValueError: If `epochs_between_evals` is smaller than 1.
  """
  if epochs_between_evals < 1:
    raise ValueError('epochs_between_evals must be >= 1, got %r'
                     % (epochs_between_evals,))
  full_cycles, remainder = divmod(max(train_epochs, 0), epochs_between_evals)
  cycles = [epochs_between_evals] * full_cycles
  if remainder:
    cycles.append(remainder)
  return cycles


def run_mnist(flags_obj):
  """Run MNIST training and eval loop, then optionally export the model.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # NOTE(review): apply_clean is a helper from the official utils; it
  # presumably deletes flags_obj.model_dir when flags_obj.clean is set --
  # confirm in official.utils.misc.model_helpers.
  model_helpers.apply_clean(flags_obj)

  # tf.ConfigProto configures how the tf.Session executes the graph.
  session_config = tf.ConfigProto(
      # Threads used to run *independent* ops in parallel
      # (e.g. c = a + b and d = e + f); 0 lets TensorFlow choose.
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      # Threads used *inside* a single op such as a matrix multiplication;
      # 0 lets TensorFlow choose.
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      # Fall back to an available device when an op cannot be placed on the
      # requested one (e.g. no GPU kernel / no GPU present).
      allow_soft_placement=True)

  # Distribution strategy built from the number of GPUs and the all-reduce
  # algorithm given on the command line.
  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  # channels_first: (batch, channels, height, width);
  # channels_last:  (batch, height, width, channels).
  # Without an explicit flag, use channels_first when TensorFlow was built
  # with CUDA and channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')

  # The Estimator wraps model_fn. Checkpoints and event files go to
  # model_dir, and `params` is forwarded to model_fn.
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'data_format': data_format,
      })

  def make_train_input_fn(num_epochs):
    """Return an input_fn that iterates the training set `num_epochs` times."""

    def train_input_fn():
      """Prepare data for training."""
      # When choosing shuffle buffer sizes, larger sizes result in better
      # randomness, while smaller sizes use less memory. MNIST is a small
      # enough dataset that we can easily shuffle the full epoch.
      ds = dataset.train(flags_obj.data_dir)
      ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
      return ds.repeat(num_epochs)

    return train_input_fn

  def eval_input_fn():
    """Prepare data for evaluation (single pass over the test set)."""
    return dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size).make_one_shot_iterator().get_next()

  # Hooks requested via flags_obj.hooks. model_fn names the 'learning_rate',
  # 'cross_entropy' and 'train_accuracy' tensors so a LoggingTensorHook can
  # log them.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Train and evaluate: each cycle trains epochs_between_evals epochs (the
  # last one possibly fewer, so exactly train_epochs are run), then
  # evaluates; stop early once accuracy reaches stop_threshold.
  for cycle_epochs in _epochs_per_cycle(flags_obj.train_epochs,
                                        flags_obj.epochs_between_evals):
    mnist_classifier.train(input_fn=make_train_input_fn(cycle_epochs),
                           hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags_obj.export_dir is not None:
    # Graph input for serving: a batch of 28x28 float images supplied by the
    # client at inference time (not a memory pre-allocation).
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
264 
265 
def main(_):
  """Entry point for absl_app.run; the positional argv argument is unused."""
  parsed_flags = flags.FLAGS
  run_mnist(parsed_flags)
268 
269 
if __name__ == '__main__':
  # Show INFO-level TensorFlow log messages.
  tf.logging.set_verbosity(tf.logging.INFO)
  # The MNIST flags must be registered before absl parses the command line.
  define_mnist_flags()
  # absl parses the command line into flags.FLAGS, then calls main(argv).
  absl_app.run(main=main)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM