我對 MNIST 官方例程做了部分注解,分享出來,希望能幫助入門者更好地理解 TensorFlow 的運行機制。建議把代碼拷貝到 IDE 中調試,觀察具體的數據流向,並順便了解 TensorFlow 裡面用到的一些庫。
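我用的是通過 pip 安裝的 tensorflow-gpu 1.13。這段源碼的原始位置在 https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py 。代碼中 `from official...` 的幾行導入依賴 tensorflow/models 倉庫裡的 `official` 包,運行前需要把該倉庫的根目錄加入 PYTHONPATH,否則導入會失敗。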
代碼:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# absl is Abseil's Python library (pip package `absl-py`); it is a third-party
# package, NOT part of the Python standard library.
from absl import app as absl_app
from absl import flags

import tensorflow as tf  # pylint: disable=g-bad-import-order

# These modules come from the `official` package of the tensorflow/models
# repository, which must be importable (e.g. its root on PYTHONPATH).
from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers


LEARNING_RATE = 1e-4


def create_model(data_format):
    """Model to recognize digits in the MNIST dataset.

    Network structure is equivalent to:
    https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
    and
    https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

    But uses the tf.keras API.

    Args:
        data_format: Either 'channels_first' or 'channels_last'. 'channels_first'
            is typically faster on GPUs while 'channels_last' is typically faster
            on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats

    Returns:
        A tf.keras.Model.

    Raises:
        ValueError: if data_format is not one of the two supported values.
    """
    # data_format fixes the axis order of each image batch:
    #   'channels_last'  -> (batch, height, width, channels)
    #   'channels_first' -> (batch, channels, height, width)
    # input_shape omits the batch axis, which is why it has only three entries;
    # the channel count is 1 because MNIST images are single-channel grayscale.
    if data_format == 'channels_first':
        input_shape = [1, 28, 28]
    elif data_format == 'channels_last':
        input_shape = [28, 28, 1]
    else:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            "data_format must be 'channels_first' or 'channels_last', got %r"
            % (data_format,))

    layers = tf.keras.layers
    # MaxPooling2D has no trainable weights; the same instance is placed after
    # both convolutions, as in the upstream example.
    max_pool = layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)
    # The model consists of a sequential chain of layers, so tf.keras.Sequential
    # (a subclass of tf.keras.Model) makes for a compact description.
    return tf.keras.Sequential(
        [
            # Each example arrives as a flat vector of 28 * 28 = 784 values and
            # is reshaped to an image: (784,) -> [1, 28, 28] or [28, 28, 1].
            layers.Reshape(
                target_shape=input_shape,
                input_shape=(28 * 28,)),
            # Convolution: 32 filters (output channels), 5x5 kernel.
            layers.Conv2D(
                32,
                5,
                padding='same',
                data_format=data_format,
                activation=tf.nn.relu),
            max_pool,
            # Convolution: 64 filters, 5x5 kernel.
            layers.Conv2D(
                64,
                5,
                padding='same',
                data_format=data_format,
                activation=tf.nn.relu),
            max_pool,
            # Collapse every axis except the batch axis (axis 0) into one.
            layers.Flatten(),
            # Fully connected layer with 1024 units.
            layers.Dense(1024, activation=tf.nn.relu),
            # Drop 40% of the units; only active when the model is called with
            # training=True (see model_fn).
            layers.Dropout(0.4),
            # Output layer: 10 raw logits, one per digit class. Softmax is
            # applied later in model_fn / inside the loss.
            layers.Dense(10)
        ])


def define_mnist_flags():
    """Register this script's command-line flags on absl's flags.FLAGS.

    Pulls in the shared base / performance / image flag groups from the
    official utilities, then overrides the defaults for the data directory,
    model directory, batch size, epoch count and early-stop threshold.
    """
    flags_core.define_base()
    flags_core.define_performance(num_parallel_calls=False)
    flags_core.define_image()
    flags.adopt_module_key_flags(flags_core)
    # Script-specific default values for the flags defined above.
    flags_core.set_defaults(data_dir='./tmp/mnist_data',
                            model_dir='./tmp/mnist_model',
                            batch_size=100,
                            train_epochs=40,
                            stop_threshold=0.998)


def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Args:
        features: the image tensor, or a dict holding it under the key 'image'.
        labels: integer class labels, as required by
            tf.losses.sparse_softmax_cross_entropy (unused in PREDICT mode).
        mode: a tf.estimator.ModeKeys value selecting predict / train / eval.
        params: the dict passed as Estimator(params=...); must contain
            'data_format'.

    Returns:
        A tf.estimator.EstimatorSpec for the requested mode.
    """
    # params is the dict given to tf.estimator.Estimator in run_mnist.
    model = create_model(params['data_format'])
    image = features
    # Features may arrive as a dict (the serving input built at export time
    # uses {'image': ...}) or as a bare tensor.
    if isinstance(image, dict):
        image = features['image']

    # Prediction (inference) mode: no labels, no loss.
    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # Training mode.
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

        logits = model(image, training=True)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        # tf.metrics.accuracy returns (value, update_op); index 1 is the update
        # op, which yields the running accuracy after each update.
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(LEARNING_RATE, 'learning_rate')
        tf.identity(loss, 'cross_entropy')
        tf.identity(accuracy[1], name='train_accuracy')

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar('train_accuracy', accuracy[1])

        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(
                loss, tf.train.get_or_create_global_step()))

    # Evaluation mode.
    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                'accuracy':
                    tf.metrics.accuracy(
                        labels=labels, predictions=tf.argmax(logits, axis=1)),
            })


def run_mnist(flags_obj):
    """Run MNIST training and eval loop.

    Args:
        flags_obj: An object containing parsed flag values.
    """
    # Helper from the official utilities; per the --clean flag it is meant to
    # clear existing files in flags_obj.model_dir before training.
    model_helpers.apply_clean(flags_obj)

    # tf.ConfigProto configures how the underlying tf.Session executes.
    session_config = tf.ConfigProto(
        # Threads used to run independent ops in parallel
        # (e.g. c = a + b and d = e + f); 0 lets TensorFlow choose.
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        # Threads used inside a single op (e.g. one matrix multiplication);
        # 0 lets TensorFlow choose.
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        # If an op cannot run on the device it was assigned (e.g. no GPU
        # kernel or no GPU present), fall back to a usable device such as the
        # CPU instead of raising an error.
        allow_soft_placement=True)

    # Build a distribution strategy from the number of GPUs requested and the
    # all-reduce algorithm used to combine gradients across devices.
    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # RunConfig describes the execution environment. The Estimator writes all
    # of its outputs (checkpoints, event files, ...) under model_dir, uses
    # session_config for its sessions and train_distribute for training.
    run_config = tf.estimator.RunConfig(
        train_distribute=distribution_strategy, session_config=session_config)

    data_format = flags_obj.data_format
    # channels_first: channel axis right after the batch axis (N, C, H, W).
    # channels_last:  channel axis at the end (N, H, W, C).
    if data_format is None:
        # is_built_with_cuda() reports how TensorFlow was built, not whether a
        # GPU is actually available on this machine.
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')

    # The Estimator wraps model_fn: given inputs and a mode it builds the ops
    # needed for training, evaluation or prediction.
    mnist_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        # Checkpoints and summaries are written here. Training data is NOT read
        # from model_dir; the input functions below read flags_obj.data_dir.
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            # Image axis order, forwarded to create_model via model_fn.
            'data_format': data_format,
        })

    # Set up training and evaluation input functions. They are nested so that
    # they can read flags_obj from the enclosing scope.
    def train_input_fn():
        """Prepare data for training."""

        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. MNIST is a small
        # enough dataset that we can easily shuffle the full epoch.
        ds = dataset.train(flags_obj.data_dir)
        ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of
        # times during each training session.
        ds = ds.repeat(flags_obj.epochs_between_evals)
        return ds

    def eval_input_fn():
        """Prepare batched test data for evaluation."""
        return dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size).make_one_shot_iterator().get_next()

    # Set up hook that outputs training logs every 100 steps.
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks, model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    # Train and evaluate model: each round trains for epochs_between_evals
    # epochs, then evaluates once. Integer division drops any remainder epochs
    # when train_epochs is not a multiple of epochs_between_evals.
    for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        # Stop early once evaluation accuracy reaches stop_threshold.
        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    # Export the model
    if flags_obj.export_dir is not None:
        # Serving-time input: a batch of 28x28 float images that must be fed
        # when the exported model runs. The model's Reshape layer turns each
        # 784-value image into input_shape.
        image = tf.placeholder(tf.float32, [None, 28, 28])
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'image': image,
        })
        # Write a SavedModel under export_dir.
        mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)


def main(_):
    """absl entry point: run training/eval with the parsed flags."""
    run_mnist(flags.FLAGS)


if __name__ == '__main__':
    # Show INFO-level log messages.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Register this script's flags on flags.FLAGS before they are parsed.
    define_mnist_flags()
    # Parse the command line into flags.FLAGS, then call main.
    absl_app.run(main)