用keras搭好模型架構之后的下一步,就是執行編譯操作。在編譯時,經常需要指定三個參數
- loss
- optimizer
- metrics
這三個參數有兩類選擇:
- 使用字符串
- 使用標識符,如keras.losses,keras.optimizers,metrics包下面的函數
例如:
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy'])
因為有時可以使用字符串,有時可以使用標識符,令人很想知道背后是如何操作的。下面分別針對optimizer,loss,metrics三種對象的獲取進行研究。
optimizer
一個模型只能有一個optimizer,在執行編譯的時候只能指定一個optimizer。
在keras.optimizers.py中,有一個get函數,用於根據用戶傳進來的optimizer參數獲取優化器的實例:
def get(identifier):
# 如果后端是tensorflow並且使用的是tensorflow自帶的優化器實例,可以直接使用tensorflow原生的優化器
if K.backend() == 'tensorflow':
# Wrap TF optimizer instances
if isinstance(identifier, tf.train.Optimizer):
return TFOptimizer(identifier)
# 如果以json串的形式定義optimizer並進行參數配置
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
# 如果以字符串形式指定optimizer,那么使用優化器的默認配置參數
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
if isinstance(identifier, Optimizer):
# 如果使用keras封裝的Optimizer的實例
return identifier
else:
raise ValueError('Could not interpret optimizer identifier: ' +
str(identifier))
其中,deserilize(config)函數的作用就是把optimizer反序列化制造一個實例。
loss
keras.losses函數也有一個get(identifier)方法。其中需要注意以下一點:
如果identifier是可調用的一個函數名,也就是一個自定義的損失函數,這個損失函數返回值是一個張量。這樣就輕而易舉的實現了自定義損失函數。除了使用str和dict類型的identifier,我們也可以直接使用keras.losses包下面的損失函數。
def get(identifier):
if identifier is None:
return None
if isinstance(identifier, six.string_types):
identifier = str(identifier)
return deserialize(identifier)
if isinstance(identifier, dict):
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'loss function identifier:', identifier)
metrics
在model.compile()函數中,optimizer和loss都是單數形式,只有metrics是復數形式。因為一個模型只能指明一個optimizer和loss,卻可以指明多個metrics。metrics也是三者中處理邏輯最為復雜的一個。
在keras最核心的地方keras.engine.train.py中有如下處理metrics的函數。這個函數其實就做了兩件事:
- 根據輸入的metric找到具體的metric對應的函數
- 計算metric張量
在尋找metric對應函數時,有兩種步驟:
- 使用字符串形式指明准確率和交叉熵
- 使用keras.metrics.py中的函數
def handle_metrics(metrics, weights=None):
metric_name_prefix = 'weighted_' if weights is not None else ''
for metric in metrics:
# 如果metrics是最常見的那種:accuracy,交叉熵
if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
# custom handling of accuracy/crossentropy
# (because of class mode duality)
output_shape = K.int_shape(self.outputs[i])
# 如果輸出維度是1或者損失函數是二分類損失函數,那么說明是個二分類問題,應該使用二分類的accuracy和二分類的的交叉熵
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.binary_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.binary_crossentropy
# 如果損失函數是sparse_categorical_crossentropy,那么目標y_input就不是one-hot的,所以就需要使用sparse的多類准去率和sparse的多類交叉熵
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy/crossentropy
# with sparse targets
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.sparse_categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.sparse_categorical_crossentropy
else:
# case: categorical accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.categorical_crossentropy
if metric in ('accuracy', 'acc'):
suffix = 'acc'
elif metric in ('crossentropy', 'ce'):
suffix = 'ce'
weighted_metric_fn = weighted_masked_objective(metric_fn)
metric_name = metric_name_prefix + suffix
else:
# 如果輸入的metric不是字符串,那么就調用metrics模塊獲取
metric_fn = metrics_module.get(metric)
weighted_metric_fn = weighted_masked_objective(metric_fn)
# Get metric name as string
if hasattr(metric_fn, 'name'):
metric_name = metric_fn.name
else:
metric_name = metric_fn.__name__
metric_name = metric_name_prefix + metric_name
with K.name_scope(metric_name):
metric_result = weighted_metric_fn(y_true, y_pred,
weights=weights,
mask=masks[i])
# Append to self.metrics_names, self.metric_tensors,
# self.stateful_metric_names
if len(self.output_names) > 1:
metric_name = self.output_names[i] + '_' + metric_name
# Dedupe name
j = 1
base_metric_name = metric_name
while metric_name in self.metrics_names:
metric_name = base_metric_name + '_' + str(j)
j += 1
self.metrics_names.append(metric_name)
self.metrics_tensors.append(metric_result)
# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
if isinstance(metric_fn, Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
無論怎么使用metric,最終都會變成metrics包下面的函數。當使用字符串形式指明accuracy和crossentropy時,keras會非常智能地確定應該使用metrics包下面的哪個函數。因為metrics包下的那些metric函數有不同的使用場景,例如:
- 有的處理的是one-hot形式的y_input(數據的類別),有的處理的是非one-hot形式的y_input
- 有的處理的是二分類問題的metric,有的處理的是多分類問題的metric
當使用字符串“accuracy”和“crossentropy”指明metric時,keras會根據損失函數、輸出層的shape來確定具體應該使用哪個metric函數。在任何情況下,直接使用metrics下面的函數名是總不會出錯的。
keras.metrics.py文件中也有一個get(identifier)函數用於獲取metric函數。
def get(identifier):
if isinstance(identifier, dict):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif isinstance(identifier, six.string_types):
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'metric function identifier:', identifier)
如果identifier是字符串或者字典,那么會根據identifier反序列化出一個metric函數。
如果identifier本身就是一個函數名,那么就直接返回這個函數名。這種方式就為自定義metric提供了巨大便利。
keras中的設計哲學堪稱完美。