Estimator是TensorFlow的高階API。除了TensorFlow官方提供的內置Estimator之外,用戶也可以實現自定義的Estimator。
Estimator定義
Estimator的構造函數如下:
# Estimator constructor signature (excerpt; the body is not shown).
def __init__(self,
model_fn, # Defines the model: builds the train, eval or predict graph depending on the mode.
model_dir=None, # Model output directory
config=None, # Configuration
params=None, # Extra parameters for a custom Estimator; forwarded to model_fn as `params`
warm_start_from=None): # Source to warm-start the model from
其中最核心的參數為model_fn,其接口如下:
# Interface a user-supplied model_fn must implement (signature only).
def _model_fn(features, # Features: a Tensor or a dict of Tensors
labels, # Labels
mode, # Mode: ModeKeys.TRAIN, ModeKeys.EVAL or ModeKeys.PREDICT
params, # Custom parameters, i.e. the `params` given to the Estimator constructor above
config): # Configuration
model_fn會被Estimator多次調用,通過調用TensorFlow的layer來實現模型。通過模式字段(ModeKeys.TRAIN、ModeKeys.EVAL、ModeKeys.PREDICT)判斷當前是訓練、評估還是預測階段,並分別構造不同的圖。model_fn的返回值為EstimatorSpec,Estimator使用其中的訓練OP(train_op)、loss和預測結果(predictions),即可驅動完成訓練、評估和預測。
EstimatorSpec的定義如下
# EstimatorSpec constructor signature (excerpt): model_fn returns an instance of this.
def __new__(cls,
mode, # Mode
predictions=None, # Prediction Tensor or dict; required when mode is PREDICT.
loss=None, # Loss Tensor; required when mode is TRAIN or EVAL.
train_op=None, # Training op; required when mode is TRAIN.
eval_metric_ops=None, # Dict of evaluation ops
export_outputs=None,
training_chief_hooks=None, # Training hooks passed as chief_only_hooks (run on the chief only)
training_hooks=None, # Training hooks added to the worker hooks of the training session
scaffold=None,
evaluation_hooks=None, # Hooks added to the evaluation session
prediction_hooks=None):
訓練
Estimator的訓練接口如下
def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
    """Train the model on the data produced by ``input_fn``.

    Args:
      input_fn: Callable returning a ``(features, labels)`` tuple for training.
      hooks: Optional list of hooks that customize behavior during training.
      steps: Number of training steps.
      max_steps: Total number of training steps.
      saving_listeners: Optional listeners forwarded to the training loop.

    Returns:
      ``self``, so calls can be chained.

    Raises:
      ValueError: if both ``steps`` and ``max_steps`` are given, or if either
        one is not positive.
    """
    # The two stopping criteria are mutually exclusive and must be positive.
    if steps is not None and max_steps is not None:
        raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    with context.graph_mode():
        # Copy into a fresh list: this accepts the default hooks=None and
        # avoids mutating the caller's list when the step hooks are appended.
        hooks = list(hooks or [])
        hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
        saving_listeners = list(saving_listeners or [])
        loss = self._train_model(input_fn, hooks, saving_listeners)
        logging.info('Loss for final step: %s.', loss)
        return self
_train_model根據不同的配置(是否設置了_train_distribution),分別走到分佈式訓練或本地訓練的函數。
# Route training to the distributed or the local (default) implementation.
def _train_model(self, input_fn, hooks, saving_listeners):
# A truthy _train_distribution (presumably a distribution strategy -- confirm)
# selects the distributed path; otherwise train locally.
if self._train_distribution:
return self._train_model_distributed(input_fn, hooks, saving_listeners)
else:
return self._train_model_default(input_fn, hooks, saving_listeners)
我們先看本地訓練的實現。
def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Train locally: build the TRAIN graph in a fresh graph, then run the loop.

    Creates the global step, calls ``input_fn`` for features and labels,
    calls ``model_fn`` in TRAIN mode, and hands the resulting EstimatorSpec to
    ``_train_with_estimator_spec``. Its result (the final loss) is returned.
    """
    # Hooks contributed by input_fn. _train_with_estimator_spec later appends
    # the caller's hooks and the default training hooks to this same list.
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        # Called to create the global step in g; the tensor itself is looked
        # up again below, after model_fn has built the graph.
        self._create_and_assert_global_step(g)
        features, labels, input_hooks = (
            self._get_features_and_labels_from_input_fn(
                input_fn, ModeKeys.TRAIN))
        worker_hooks.extend(input_hooks)
        estimator_spec = self._call_model_fn(
            features, labels, ModeKeys.TRAIN, self.config)
        global_step_tensor = training_util.get_global_step(g)
        return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                               hooks, global_step_tensor,
                                               saving_listeners)
其流程為:先創建global_step,然後調用input_fn得到訓練特徵和標籤,再調用model_fn構建訓練圖,最後進入training loop。
_get_features_and_labels_from_input_fn最終會在CPU設備上調用input_fn,得到訓練特徵和標籤。
# Core of the input_fn call (excerpt): input_fn runs with its ops placed on
# the CPU device.
# NOTE(review): kwargs is assembled in code elided from this excerpt.
with ops.device('/cpu:0'):
return input_fn(**kwargs)
_call_model_fn會調用model_fn,注意這裡傳入的模式為ModeKeys.TRAIN,用於表示訓練階段。
def _call_model_fn(self, features, labels, mode, config):
    """Call the user's ``model_fn`` and return its result.

    ``features`` is always passed. Of ``labels``, ``mode``, ``params`` and
    ``config``, only the arguments that ``model_fn`` actually declares are
    passed, so a model_fn may omit the ones it does not need.

    Args:
      features: Features produced by ``input_fn``.
      labels: Labels produced by ``input_fn``, or None (e.g. for PREDICT).
      mode: One of ModeKeys.TRAIN / ModeKeys.EVAL / ModeKeys.PREDICT.
      config: Configuration forwarded to ``model_fn``.

    Returns:
      Whatever ``model_fn`` returns (expected to be an EstimatorSpec).

    Raises:
      ValueError: if ``labels`` is not None but ``model_fn`` declares no
        ``labels`` argument.
    """
    import inspect

    model_fn_args = inspect.signature(self._model_fn).parameters
    kwargs = {}
    if 'labels' in model_fn_args:
        kwargs['labels'] = labels
    elif labels is not None:
        # Silently dropping labels would hide an input_fn/model_fn mismatch.
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
        kwargs['mode'] = mode
    if 'params' in model_fn_args:
        kwargs['params'] = self.params
    if 'config' in model_fn_args:
        kwargs['config'] = config
    return self._model_fn(features=features, **kwargs)
下面看_train_with_estimator_spec的實現。
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners):
    """Run the training loop for an already-built TRAIN graph.

    Appends the caller's hooks and the default training hooks to
    ``worker_hooks``, then repeatedly runs ``train_op`` and ``loss`` in a
    MonitoredTrainingSession until a hook requests a stop.

    NOTE(review): ``global_step_tensor`` and ``saving_listeners`` are not used
    in the visible part of this excerpt; the hook constructor arguments shown
    as ``...`` are elided.

    Returns:
      The loss value of the last step run, or None if no step ran.
    """
    # Warm-start when warm-start settings were configured.
    if self._warm_start_settings:
        warm_starting_util.warm_start(*self._warm_start_settings)

    # Create the hooks. EstimatorSpec hook fields default to None, so every
    # iteration over them is guarded.
    worker_hooks.extend(hooks or [])
    worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
    worker_hooks.append(training.LoggingTensorHook(...))
    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks or [])
    # NOTE(review): unused in this excerpt; presumably consumed by elided
    # checkpoint-saving logic that may also add to chief_hooks -- confirm.
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    worker_hooks.extend(estimator_spec.training_hooks or [])
    worker_hooks.append(training.SummarySaverHook(...))
    worker_hooks.append(training.StepCounterHook(...))

    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps
    with training.MonitoredTrainingSession(
            master=self._config.master,
            is_chief=self._config.is_chief,
            checkpoint_dir=self._model_dir,
            scaffold=estimator_spec.scaffold,
            hooks=worker_hooks,
            chief_only_hooks=(
                tuple(chief_hooks)
                + tuple(estimator_spec.training_chief_hooks or ())),
            save_checkpoint_secs=0,  # Saving is handled by a hook.
            save_summaries_steps=save_summary_steps,
            config=self._session_config,
            log_step_count_steps=log_step_count_steps) as mon_sess:
        loss = None
        any_step_done = False
        # Training loop: one train_op + loss evaluation per iteration.
        while not mon_sess.should_stop():
            _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
            any_step_done = True
    if not any_step_done:
        logging.warning('Training with estimator made no steps. '
                        'Perhaps input is empty or misspecified.')
    return loss
前半部分主要在創建Hook,後半部分使用MonitoredTrainingSession執行training loop:反覆運行train_op和loss,直到session要求停止。
評估
評估的接口為
# Public evaluation entry point (signature only). input_fn has the same
# contract as in train(): it returns the evaluation features and labels.
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
其中input_fn與訓練函數中的input_fn接口相同,調用後返回評估用的特徵和標籤。評估最終會調用到下面的函數:
# Evaluation driver that evaluate() ends up calling (excerpt; part of the body
# is elided at the `...`).
# NOTE(review): strategy and steps are not used in the visible part of this excerpt.
def _actual_eval(self,
input_fn,
strategy=None,
steps=None,
hooks=None,
checkpoint_path=None,
name=None):
...
# Build the evaluation graph and hooks, then run evaluation, passing
# eval_dir(name) as the output directory.
def _evaluate():
(scaffold, update_op, eval_dict, all_hooks) = (
self._evaluate_build_graph(input_fn, hooks, checkpoint_path))
return self._evaluate_run(
checkpoint_path=checkpoint_path,
scaffold=scaffold,
update_op=update_op,
eval_dict=eval_dict,
all_hooks=all_hooks,
output_dir=self.eval_dir(name))
return _evaluate()
_evaluate_build_graph的實現如下:
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
    """Builds the graph and related hooks to run evaluation.

    Args:
      input_fn: Callable returning the evaluation features and labels.
      hooks: Optional list of extra hooks for the evaluation session.
      checkpoint_path: Not used in this excerpt; kept for the caller.

    Returns:
      A ``(scaffold, update_op, eval_dict, all_hooks)`` tuple. ``all_hooks``
      holds the input_fn hooks, then ``hooks``, then the EstimatorSpec's
      evaluation hooks.
    """
    (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
        self._call_model_fn_eval(input_fn, self.config))
    all_hooks = list(input_hooks)
    # hooks defaults to None, which list.extend() would reject.
    all_hooks.extend(hooks or [])
    all_hooks.extend(evaluation_hooks or [])
    if scaffold and scaffold.local_init_op:
        # Create the evaluation step before rebuilding the local init op.
        evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
        # Group the user's local_init_op with the default one. Other fields
        # are copied from the user's scaffold via copy_from_scaffold.
        scaffold = monitored_session.Scaffold(
            local_init_op=control_flow_ops.group(
                scaffold.local_init_op,
                monitored_session.Scaffold.default_local_init_op()),
            copy_from_scaffold=scaffold
        )
    return scaffold, update_op, eval_dict, all_hooks
_evaluate_build_graph會調用_call_model_fn_eval進行評估構圖,然後返回scaffold、update_op、eval_dict以及所有的hook。
def _call_model_fn_eval(self, input_fn, config):
"""Call model_fn for evaluation and handle return values."""
# Evaluation features/labels, plus any hooks that input_fn supplies.
features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
input_fn, ModeKeys.EVAL)
# Build the evaluation graph by calling model_fn in EVAL mode.
estimator_spec = self._call_model_fn(
features, labels, ModeKeys.EVAL, config)
# Presumably validates the user's eval_metric_ops and adds a loss metric
# (inferred from the helper's name -- confirm).
eval_metric_ops = _verify_and_create_loss_metric(
estimator_spec.eval_metric_ops, estimator_spec.loss)
# update_op is later run on every eval step (as eval_ops); eval_dict is
# passed on as final_ops.
update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
input_hooks, update_op, eval_dict)
_call_model_fn_eval的流程為:從input_fn獲取評估用的特徵和標籤,然後以ModeKeys.EVAL模式調用model_fn進行評估構圖。
_actual_eval調用完_evaluate_build_graph之後,接著調用_evaluate_run:
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
"""Run evaluation."""
# Delegate to _evaluate_once: update_op becomes eval_ops (run on every
# evaluation step); eval_dict becomes final_ops (presumably fetched once
# evaluation ends -- that handling is elided from the _evaluate_once excerpt).
# NOTE(review): output_dir is not used in the visible part of this excerpt.
eval_results = evaluation._evaluate_once( # pylint: disable=protected-access
checkpoint_path=checkpoint_path,
master=self._config.evaluation_master,
scaffold=scaffold,
eval_ops=update_op,
final_ops=eval_dict,
hooks=all_hooks,
config=self._session_config)
...
def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
    """Run a single evaluation pass starting from ``checkpoint_path``.

    ``eval_ops`` (augmented with ``update_eval_step``) are run repeatedly in
    a MonitoredSession until a hook requests a stop.

    NOTE(review): ``update_eval_step``, the handling of ``final_ops`` /
    ``final_ops_feed_dict`` and the return value live in code elided from
    this excerpt.
    """
    # Prepare eval_ops by appending the eval-step update op and keeping the
    # container kind. Skip this when there is nothing to run; otherwise None
    # would be wrapped into [None, update_eval_step] and fed to session.run.
    if eval_ops is not None:
        if isinstance(eval_ops, dict):
            # Build a new dict rather than mutating the caller's.
            eval_ops = dict(eval_ops, update_eval_step=update_eval_step)
        elif isinstance(eval_ops, (tuple, list)):
            eval_ops = list(eval_ops) + [update_eval_step]
        else:
            eval_ops = [eval_ops, update_eval_step]
        # NOTE(review): not used in the visible part of this excerpt.
        eval_step_value = _get_latest_eval_step_value(eval_ops)
    # Prepare the session creator.
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_filename_with_path=checkpoint_path,
        master=master,
        config=config)
    with monitored_session.MonitoredSession(
            session_creator=session_creator, hooks=hooks) as session:
        if eval_ops is not None:
            # Evaluation loop: runs until a hook requests a stop.
            while not session.should_stop():
                session.run(eval_ops, feed_dict)
_evaluate_once執行最終的評估邏輯:先準備好評估用的ops,然後通過MonitoredSession執行評估的loop。
預測
預測的接口和實現如下,相對最為簡單。
def predict(self,
            input_fn,
            predict_keys=None,
            hooks=None,
            checkpoint_path=None,
            yield_single_examples=True):
    """Yield predictions for the features returned by ``input_fn``.

    Args:
      input_fn: Callable returning the prediction features.
      predict_keys: Optional selection of prediction keys, applied through
        ``_extract_keys``.
      hooks: Optional list of extra hooks for the prediction session.
      checkpoint_path: Checkpoint to restore before predicting.
      yield_single_examples: If True, split each evaluated batch and yield
        one example at a time; otherwise yield whole batches.

    Yields:
      Evaluated predictions: single examples or whole batches.

    Raises:
      ValueError: if the entries of a dict of predictions do not share a
        single batch length.
    """
    with ops.Graph().as_default():
        # Get the prediction features from `input_fn`.
        features, input_hooks = self._get_features_from_input_fn(
            input_fn, ModeKeys.PREDICT)
        # No labels at prediction time.
        estimator_spec = self._call_model_fn(
            features, None, ModeKeys.PREDICT, self.config)
        predictions = self._extract_keys(
            estimator_spec.predictions, predict_keys)
        # Hook order: input_fn hooks, caller hooks, model prediction hooks.
        all_hooks = list(input_hooks)
        all_hooks.extend(hooks or [])
        all_hooks.extend(estimator_spec.prediction_hooks or [])
        with training.MonitoredSession(
                session_creator=training.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    master=self._config.master,
                    scaffold=estimator_spec.scaffold,
                    config=self._session_config),
                hooks=all_hooks) as mon_sess:
            while not mon_sess.should_stop():
                preds_evaluated = mon_sess.run(predictions)
                if not yield_single_examples:
                    yield preds_evaluated
                elif not isinstance(preds_evaluated, dict):
                    # One batched value: yield its rows one by one.
                    for pred in preds_evaluated:
                        yield pred
                else:
                    # A dict of batched values: yield one dict per example.
                    batch_lengths = {
                        len(value) for value in preds_evaluated.values()}
                    if len(batch_lengths) != 1:
                        raise ValueError(
                            'Batch length of predictions should be same. %s'
                            % batch_lengths)
                    for i in range(batch_lengths.pop()):
                        yield {key: value[i]
                               for key, value in preds_evaluated.items()}
導出模型
Estimator最後一個重要的接口是導出模型的接口:
# Export the trained model as a SavedModel for a single mode (PREDICT unless
# experimental_mode says otherwise). Returns whatever _export_all_saved_models
# returns.
def export_saved_model(
self, export_dir_base, serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
experimental_mode=ModeKeys.PREDICT):
# Single-entry {mode: receiver fn} map in the form the shared exporter takes.
input_receiver_fn_map = {experimental_mode: serving_input_receiver_fn}
# Default attributes are always stripped on this path.
return self._export_all_saved_models(
export_dir_base,
input_receiver_fn_map,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path,
strip_default_attrs=True)
# Build and save a SavedModel for the modes in input_receiver_fn_map
# (excerpt: only the PREDICT branch is shown).
# NOTE(review): temp_export_dir and save_variables are defined in code elided
# from this excerpt. There is no return statement here, so export_saved_model
# would return None; upstream presumably returns the export directory -- confirm.
def _export_all_saved_models(
self, export_dir_base, input_receiver_fn_map,
assets_extra=None, as_text=False, checkpoint_path=None,
strip_default_attrs=True):
with context.graph_mode():
builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
# Add a PREDICT MetaGraph only when a receiver fn was given for that mode.
if input_receiver_fn_map.get(ModeKeys.PREDICT):
self._add_meta_graph_for_mode(
builder, input_receiver_fn_map, checkpoint_path,
save_variables, mode=ModeKeys.PREDICT,
strip_default_attrs=strip_default_attrs)
# as_text: write the SavedModel proto in text rather than binary form.
builder.save(as_text)
內置Estimator
我們看一下LinearClassifierV2的實現
# Built-in linear classifier. It is built exactly like a custom Estimator:
# define a model_fn and pass it to the Estimator base-class constructor.
class LinearClassifierV2(estimator.EstimatorV2):
def __init__(self,
feature_columns,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Ftrl',
config=None,
warm_start_from=None,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
sparse_combiner='sum'):
# Classification head; judging by the helper's name it is binary or
# multi-class depending on n_classes (confirm in head_utils).
head = head_utils.binary_or_multi_class_head(
n_classes, weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction)
# model_fn closure: captures head, feature_columns, optimizer and
# sparse_combiner, and forwards to the shared linear model implementation.
# It declares no `params` argument.
def _model_fn(features, labels, mode, config):
"""Call the defined shared _linear_model_fn."""
return _linear_model_fn_v2(
features=features,
labels=labels,
mode=mode,
head=head,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
config=config,
sparse_combiner=sparse_combiner)
# Register the closure as this Estimator's model_fn.
super(LinearClassifierV2, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)
可以看到內置Estimator的實現和自定義Estimator的實現沒什麼區別,也是通過實現model_fn並將其傳給Estimator的構造函數得到的。