tensorflow API _ 2 (tf.app.flags.FLAGS)


tf.app.flags.FLAGS 的使用,主要是在用命令行執行程序時,需要傳些參數,代碼如下:
新建一個名為:app_flags.py 的文件。

#coding:utf-8 
import tensorflow as tf 
FLAGS = tf.app.flags.FLAGS 
tf.app.flags.DEFINE_string("train_data_path", "/home/libo3/train.txt", "training data dir") 
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir") 
tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query") 
tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size") 
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate") 
  
def main(unused_argv): 
    train_data_path = FLAGS.train_data_path 
    print("train_data_path", train_data_path) 
    max_sentence_len = FLAGS.max_sentence_len 
    print("max_sentence_len", max_sentence_len) 
    embdeeing_size = FLAGS.embedding_size 
    print("embedding_size", embdeeing_size) 
    abc = tf.add(max_sentence_len, embdeeing_size) 
 
    init = tf.global_variables_initializer() 
 
    #with tf.Session() as sess: 
        #sess.run(init) 
        #print("abc", sess.run(abc)) 
 
    sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init) 
    with sv.managed_session() as sess: 
        print("abc:", sess.run(abc)) 
 
        # sv.saver.save(sess, "/home/yongcai/tmp/") 
  
# 使用這種方式保證了,如果此文件被其他文件 import的時候,不會執行main 函數 
if __name__ == '__main__': 
    tf.app.run()   # 解析命令行參數,調用main 函數 main(sys.argv) 

 

調用方法:

其中參數可以根據需求進行修改。

  1. python app_flags.py --train_data_path <絕對路徑 train.txt> --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05 

如果這樣調用:

  1. python app_flags.py  

則會執行程序時會自動調用程序中 default 中的參數。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM