Tensorflow項目中--FLAGS=tf.flags.FLAGS


  最近看CycleGAN的代碼,看到代碼里有FLAGS=tf.flags.FLAGS等語句,看不明白,查尋之余發現,這類語句在使用Tensorflow框架的項目里是常見的。並且在看代碼解釋時,找到一個博主關於這部分只是的梳理,有解釋加示例非常清楚,所以就直接應用該作者的文章。


內容包含如下幾個我們經常看到的幾個函數:

①tf.flags.DEFINE_xxx()

②FLAGS = tf.flags.FLAGS

③FLAGS._parse_flags()

簡單的說:

   用於幫助我們添加命令行的可選參數。也就是說可以不用反復修改源代碼中的參數,而是利用該函數可以實現在命令行中選擇需要設定或者修改的參數來運行程序。

舉個栗子:

程序train.py文件中的小部分代碼如下所示:

 1 FLAGS = tf.flags.FLAGS
 2 
 3 tf.flags.DEFINE_string('name', 'default', 'name of the model')
 4 tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
 5 tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
 6 tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
 7 tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
 8 tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
 9 tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
10 tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
11 tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
12 tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
13 tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
14 tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
15 tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
16 tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')

在命令行中我們為了執行train.py文件,在命令行中輸入:

python train.py \
  --input_file data/shakespeare.txt  \         
  --name shakespeare \
  --num_steps 50 \
  --num_seqs 32 \
  --learning_rate 0.01 \
  --max_steps 20000

通過輸入不同的文件名、參數,可以快速完成程序的調參和加載訓練集的操作,不需要進入源碼中更改。


 

實踐操作

現在我們有如下代碼:

 1 import tensorflow as tf
 2 #取上述代碼中一部分進行實驗  
 3 tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')   
 4 tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
 5 tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
 6 
 7 #通過print()確定下面內容的功能
 8 FLAGS = tf.flags.FLAGS #FLAGS保存命令行參數的數據
 9 FLAGS._parse_flags() #將其解析成字典存儲到FLAGS.__flags中
10 print(FLAGS.__flags)
11 
12 print(FLAGS.num_seqs)
13 
14 print("\nParameters:")
15 for attr, value in sorted(FLAGS.__flags.items()):
16     print("{}={}".format(attr.upper(), value))
17 print("")

按照我現在編寫這個博客時間節點來說,第九行的 FLAGS._parse_flags()   在新版本的Tensorflow中不再使用了,如果因為版本造成編譯出錯,會返回AttributeError: _parse_flags。所以從另一個博主那看到新的的代碼為  FLAGS.flag_values_dict()   (解析成字典並且存儲到FLAGS.__flags中)。

注意點:

  1.   修改參數的方式
  2.   調用參數的方式
  3.   描述參數的方式
  4.   定義參數的類型

 

原作者:,鏈接:https://blog.csdn.net/zzq060143/article/details/81952848 

修改代碼的作者:   ,鏈接:https://blog.csdn.net/spring_willow/article/details/80115206

非常感謝作者的分享。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM