keras中的loss、optimizer、metrics


用keras搭好模型架構之后的下一步,就是執行編譯操作。在編譯時,經常需要指定三個參數

  • loss
  • optimizer
  • metrics

這三個參數有兩類選擇:

  • 使用字符串
  • 使用標識符,如keras.losses,keras.optimizers,metrics包下面的函數

例如:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])

因為有時可以使用字符串,有時可以使用標識符,令人很想知道背后是如何操作的。下面分別針對optimizer,loss,metrics三種對象的獲取進行研究。

optimizer

一個模型只能有一個optimizer,在執行編譯的時候只能指定一個optimizer。
在keras.optimizers.py中,有一個get函數,用於根據用戶傳進來的optimizer參數獲取優化器的實例:

def get(identifier):
    # 如果后端是tensorflow並且使用的是tensorflow自帶的優化器實例,可以直接使用tensorflow原生的優化器 
    if K.backend() == 'tensorflow':
        # Wrap TF optimizer instances
        if isinstance(identifier, tf.train.Optimizer):
            return TFOptimizer(identifier)
    # 如果以json串的形式定義optimizer並進行參數配置
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, six.string_types):
        # 如果以字符串形式指定optimizer,那么使用優化器的默認配置參數
        config = {'class_name': str(identifier), 'config': {}}
        return deserialize(config)
    if isinstance(identifier, Optimizer):
        # 如果使用keras封裝的Optimizer的實例
        return identifier
    else:
        raise ValueError('Could not interpret optimizer identifier: ' +
                         str(identifier))

其中,deserilize(config)函數的作用就是把optimizer反序列化制造一個實例。

loss

keras.losses函數也有一個get(identifier)方法。其中需要注意以下一點:

如果identifier是可調用的一個函數名,也就是一個自定義的損失函數,這個損失函數返回值是一個張量。這樣就輕而易舉的實現了自定義損失函數。除了使用str和dict類型的identifier,我們也可以直接使用keras.losses包下面的損失函數。

def get(identifier):
    if identifier is None:
        return None
    if isinstance(identifier, six.string_types):
        identifier = str(identifier)
        return deserialize(identifier)
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'loss function identifier:', identifier)

metrics

在model.compile()函數中,optimizer和loss都是單數形式,只有metrics是復數形式。因為一個模型只能指明一個optimizer和loss,卻可以指明多個metrics。metrics也是三者中處理邏輯最為復雜的一個。

在keras最核心的地方keras.engine.train.py中有如下處理metrics的函數。這個函數其實就做了兩件事:

  • 根據輸入的metric找到具體的metric對應的函數
  • 計算metric張量

在尋找metric對應函數時,有兩種步驟:

  • 使用字符串形式指明准確率和交叉熵
  • 使用keras.metrics.py中的函數
def handle_metrics(metrics, weights=None):
    metric_name_prefix = 'weighted_' if weights is not None else ''

    for metric in metrics:
        # 如果metrics是最常見的那種:accuracy,交叉熵
        if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
            # custom handling of accuracy/crossentropy
            # (because of class mode duality)
            output_shape = K.int_shape(self.outputs[i])
            # 如果輸出維度是1或者損失函數是二分類損失函數,那么說明是個二分類問題,應該使用二分類的accuracy和二分類的的交叉熵
            if (output_shape[-1] == 1 or
                self.loss_functions[i] == losses.binary_crossentropy):
                # case: binary accuracy/crossentropy
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.binary_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.binary_crossentropy
            # 如果損失函數是sparse_categorical_crossentropy,那么目標y_input就不是one-hot的,所以就需要使用sparse的多類准去率和sparse的多類交叉熵
            elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
                # case: categorical accuracy/crossentropy
                # with sparse targets
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.sparse_categorical_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.sparse_categorical_crossentropy
            else:
                # case: categorical accuracy/crossentropy
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.categorical_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.categorical_crossentropy
            if metric in ('accuracy', 'acc'):
                    suffix = 'acc'
            elif metric in ('crossentropy', 'ce'):
                    suffix = 'ce'
            weighted_metric_fn = weighted_masked_objective(metric_fn)
            metric_name = metric_name_prefix + suffix
        else:
            # 如果輸入的metric不是字符串,那么就調用metrics模塊獲取
            metric_fn = metrics_module.get(metric)
            weighted_metric_fn = weighted_masked_objective(metric_fn)
            # Get metric name as string
            if hasattr(metric_fn, 'name'):
                metric_name = metric_fn.name
            else:
                metric_name = metric_fn.__name__
            metric_name = metric_name_prefix + metric_name

        with K.name_scope(metric_name):
            metric_result = weighted_metric_fn(y_true, y_pred,
                                                weights=weights,
                                                mask=masks[i])

        # Append to self.metrics_names, self.metric_tensors,
        # self.stateful_metric_names
        if len(self.output_names) > 1:
            metric_name = self.output_names[i] + '_' + metric_name
        # Dedupe name
        j = 1
        base_metric_name = metric_name
        while metric_name in self.metrics_names:
            metric_name = base_metric_name + '_' + str(j)
            j += 1
        self.metrics_names.append(metric_name)
        self.metrics_tensors.append(metric_result)

        # Keep track of state updates created by
        # stateful metrics (i.e. metrics layers).
        if isinstance(metric_fn, Layer) and metric_fn.stateful:
            self.stateful_metric_names.append(metric_name)
            self.stateful_metric_functions.append(metric_fn)
            self.metrics_updates += metric_fn.updates

無論怎么使用metric,最終都會變成metrics包下面的函數。當使用字符串形式指明accuracy和crossentropy時,keras會非常智能地確定應該使用metrics包下面的哪個函數。因為metrics包下的那些metric函數有不同的使用場景,例如:

  • 有的處理的是one-hot形式的y_input(數據的類別),有的處理的是非one-hot形式的y_input
  • 有的處理的是二分類問題的metric,有的處理的是多分類問題的metric

當使用字符串“accuracy”和“crossentropy”指明metric時,keras會根據損失函數、輸出層的shape來確定具體應該使用哪個metric函數。在任何情況下,直接使用metrics下面的函數名是總不會出錯的。

keras.metrics.py文件中也有一個get(identifier)函數用於獲取metric函數。

def get(identifier):
    if isinstance(identifier, dict):
        config = {'class_name': str(identifier), 'config': {}}
        return deserialize(config)
    elif isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'metric function identifier:', identifier)

如果identifier是字符串或者字典,那么會根據identifier反序列化出一個metric函數。
如果identifier本身就是一個函數名,那么就直接返回這個函數名。這種方式就為自定義metric提供了巨大便利。

keras中的設計哲學堪稱完美。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM