python基礎--absl.flags


之前在tensorflow的mnist例程中看到了使用 absl.flags的方法來載入和解析參數的,出於學習的目的,就自己試驗了一下,

代碼如下:

 1 # *_*coding:utf-8 *_*
 2 # athor:auto
 3 
 4 import sys, os
 5 from absl import app
 6 from absl import flags
 7 from official.utils.flags import core as flags_core
 8 
 9 
10 FLAGS = flags.FLAGS
11 flags.DEFINE_string('gpu', None, 'comma separated list of GPU to use.')
12 
13 
14 def flagtest(argv):
15     del argv
16     if FLAGS.gpu:
17         print("gpu is %s" % FLAGS.gpu)
18         os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
19     else:
20         print('Please assign GPUs.')
21         exit()
22 
23 def main(argv):
24     flags_core.define_base()
25     flags_core.define_performance(num_parallel_calls=False)
26     flags_core.define_image()
27     flags.adopt_module_key_flags(flags_core)
28 
29 if __name__ == '__main__':
30     app.run(flagtest)
View Code

其中main中的幾個調用都是源自於tensorflow的model/official,里面的函數大多是model/official/utils/flags/core.py內定義好的一些默認參數。
在mnist例子中還可以這樣添加自定義項:

  flags_core.set_defaults(data_dir='./tmp/mnist_data',
                          model_dir='./tmp/mnist_model',
                          batch_size=100,
                          train_epochs=40,
                          stop_threshold=0.998)

  

 

參考:

https://blog.csdn.net/faith_binyang/article/details/80551941


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM