原文地址:https://blog.csdn.net/mrr1ght/article/details/81011280 。本文有刪減。
tf.train.SessionRunHook()是一個類;用來定義Hooks;
Hooks是什么,官方文檔中關於training hooks的定義是:
Hooks are tools that run in the process of training/evaluation of the model.
Hooks是在模型訓練/測試過程中的工具。Pytorch中也經常會有這個概念出現,其實也就跟keras里的callbacks一樣,hook和callback都是在訓練過程中執行特定的任務。
例如判斷是否需要停止訓練的EarlyStopping;改變學習率的LearningRateScheduler,他們都有一個共性,就是在每個step開始/結束或者每個epoch開始/結束時需要執行某個操作。如每個epoch結束都保存一次checkpoint;每個epoch結束時都判斷一次loss有沒有下降,如果loss沒有下降的輪數大於提取設定的閾值,就終止訓練。當然以上的功能我們都可以自己完全重頭實現。但是這些keras和tersorflow提供了更好的工具就是hook和callback,並且一些常用的功能都已經實現好了。說到底每個hook和callback都是按照固定格式定義了在每個step開始/結束要執行的操作,每個epoch開始/結束執行的操作。
Hooks都是繼承自父類tf.train.SessionRunHook()
,首先看一下這個父類的定義源碼;
tf.train.SessionRunHook()定義
tf.train.SessionRunHook()
類定義在tensorflow/python/training/session_run_hook.py
,類中每個函數的作用與什么時候調用都已加入函數注釋中;
class SessionRunHook(object):
"""Hook to extend calls to MonitoredSession.run()."""
def begin(self):
"""再創建會話之前調用
調用begin()時,default graph會被創建,
可在此處向default graph增加新op,begin()調用后,default graph不能再被修改
"""
pass
def after_create_session(self, session, coord): # pylint: disable=unused-argument
"""tf.Session被創建后調用
調用后會指示所有的Hooks有一個新的會話被創建
Args:
session: A TensorFlow Session that has been created.
coord: A Coordinator object which keeps track of all threads.
"""
pass
def before_run(self, run_context): # pylint: disable=unused-argument
"""調用在每個sess.run()執行之前
可以返回一個tf.train.SessRunArgs(op/tensor),在即將運行的會話中加入這些op/tensor;
加入的op/tensor會和sess.run()中已定義的op/tensor合並,然后一起執行;
Args:
run_context: A `SessionRunContext` object.
Returns:
None or a `SessionRunArgs` object.
"""
return None
def after_run(self,
run_context, # pylint: disable=unused-argument
run_values): # pylint: disable=unused-argument
"""調用在每個sess.run()之后
參數run_values是befor_run()中要求的op/tensor的返回值;
可以調用run_context.qeruest_stop()用於停止迭代
sess.run拋出任何異常after_run不會被調用
Args:
run_context: A `SessionRunContext` object.
run_values: A SessionRunValues object.
"""
pass
def end(self, session): # pylint: disable=unused-argument
"""在會話結束時調用
end()常被用於Hook想要執行最后的操作,如保存最后一個checkpoint
如果sess.run()拋出除了代表迭代結束的OutOfRange/StopIteration異常外,
end()不會被調用
Args:
session: A TensorFlow Session that will be soon closed.
"""
pass
tf.train.SessionRunHook()
類中定義的方法的參數run_context
,run_values
,run_args
,包含sess.run()
會話運行所需的一切信息,
run_context
:類tf.train.SessRunContext
的實例run_values
:類tf.train.SessRunValues
的實例run_args
:類tf.train.SessRunArgs
的實例.
這三個類會在下面詳細介紹
tf.train.SessionRunHook()的使用
(1)可以使用tf中已經預定義好的Hook,其都是tf.train.SessionRunHook()的子類;如
- StopAtStepHook:設置用於停止迭代的max_step或num_step,兩者只能設置其一
- NanTensorHook:如果loss的值為Nan,則停止訓練;
- tensorflow中有許多預定義的Hook,想了解更多的同學可以去官方文檔tf.train.下查看
(2)也可用tf.train.SessionRunHook()定義自己的Hook,並重寫類中的方法;然后把想要使用的Hook(預定義好的或者自己定義的)放到tf.train.MonitorTrainingSession()參數[Hook]列表中;
關於tf.train.MonitorTrainingSession()
參見tf.train.MonitoredTrainingSession()解析。
給一個定義自己Hook的栗子,來自cifar10
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time#duration持續的時間
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
SessRunContext/SessRunValues/SessRunArgs
這三個類都服務於sess.run(),區別如下:
- tf.train.SessRunContext和tf.train.SessRunArgs提供會話運行所需的信息,
- tf.train.SessRunValues保存會話運行的結果
(1) tf.train.SessRunArgs類
提供給會話運行的參數,與sess.run()參數定義一樣:
fethes,feeds,option
(2) tf.train.SessRunValues
用於保存sess.run()的結果,其中resluts是sess.run()返回值中對應於SessRunArgs()的返回值,
(3) tf.train.SessRunContext
SessRunContext包含sess.run()所需的一切信息
屬性:
- original_args:sess.run所需的參數,是一個tf.train.SessRunArgs實例
- session:指定要運行的會話
- stop_request:返回一個bool值,用於判斷是否停止迭代;
方法:
equest_stop(): 設置_stop_request值為True
cifar10 中的運用實例
tf.train.SessionRunHook()和tf.train.MonitorTrainingSession()一般一起使用,下面是cifar10中的使用實例
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time#duration持續的時間
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
#monitored 被監控的
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)