tensorflow的代碼中,常常會有tf.app.run()作為入口的寫法,如下:
...
# 此處省略n行代碼
...
def main(_):
...
if __name__ == "__main__":
tf.app.run()
好的,那我們就進入tf.app.run()這個函數里康康到底是什么樣子的。
@tf_export(v1=['app.run'])
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
main = main or _sys.modules['__main__'].main
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
從這里看,應該是輸入一個函數對象作為參數,用於程序的運行,如果沒輸入函數就使用默認的_sys.modules['main'].main,那就建兩個測試文件,看看執行模塊從其它模塊導入的run函數中,_sys.modules['main']的結果是什么
# file1
import sys as _sys
def run():
print(_sys.modules['__main__'])
print(_sys.modules['__main__'].main)
# file2
from file1 import run
def main(_):
pass
if __name__ == '__main__':
run()
run函數中pirnt的結果為:
<module 'main' from 'D:/data/Projects/demo/file2.py'>
<function main at 0x00000142C79B2F28>
可見,tf.app.run函數中的默認函數即為file2中main函數對象,當主函數不為main時為test時,需將test函數作為參數傳入tf.app.run()中,如下
...
# 此處省略n行代碼
...
def test(_):
# 主函數
...
if __name__ == "__main__":
tf.app.run(test)
需要注意的一點是,使用tf.app.run作為入口時,主函數至少需要一個參數,常見的寫法就是放個_作為參數,寫成main(_),防止報錯。若未加上參數,就會報TypeError: main() takes 0 positional arguments but 1 was given