很多時候在運行python代碼的時候我們需要從外部定義參數,從而避免每次都需要改動代碼。所以一般我們都會使用 argparse
這個庫。其實TensorFlow也提供了這個功能,那就是 tf.app.flags
。
使用方法很簡單
tf.app.flags.DEFINE_boolean("param_name", "default_val", "description")
上面給出的是定義一個bool變量,第一個參數是指參數名,第二個是默認值,第三個是對該變量的描述,如果不想描述可以直接用 ""。
除了bool類,我們還可以定義其他的類型數據,如:
- tf.app.flags.DEFINE_integer
- tf.app.flags.DEFINE_float
- tf.app.flags.DEFINE_string
那么如何使用呢?完整示例(假設文件名為test.py)如下:
# coding=utf-8
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags.DEFINE_integer('data', 10, "")
tf.app.flags.DEFINE_boolean("istrain", True, "")
def main(_):
print("{}".format(FLAGS.data))
print("{}".format(FLAGS.istrain))
if __name__ == '__main__':
tf.app.run()
5-6行:首先需要定義一個tf.app.flags,然后定義一個FLAGS,它是用來解析傳入的參數的。
7-8行:定義了兩個變量,分別是整型變量和bool型。
10-12行:注意使用tf.app.flags一般需要使用main函數作為入口,然后再仔細看main函數是需要傳參數的(雖然不知道為什么),否則會出現如下報錯信息:
Traceback (most recent call last):
File "d:/Code/AutoML/enas/testfiles/batch_test.py", line 19, in <module>
tf.app.run()
File "D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
TypeError: main() takes 0 positional arguments but 1 was given
所以我們可以傳入一個無意義的 _ 來解決這個問題。
下面來看看如何運行這個文件:
python test.py --data 20
>>>
20
True
可以看到如果不穿入參數則按默認值處理,否則根據傳入的值對變量進行更新。
另外對於bool型變量有個地方需要注意,因為bool只有True和False,所以無論bool變量默認值為True還是False,在變量面前加個no后都取False,,其他類型的沒有這個特權,示例如下:
python test.py --nodata --noistrain
>>>
10
False
將istrain默認值設為False,運行結果和上面一樣。