CRF keras代碼實現


這份代碼來自於蘇劍林

 

# -*- coding:utf-8 -*-

from keras.layers import Layer
import keras.backend as K

class CRF(Layer):
    """純Keras實現CRF層
    CRF層本質上是一個帶訓練參數的loss計算層,因此CRF層只用來訓練模型,
    而預測則需要另外建立模型,但是還是要用到訓練好的轉移矩陣
    """
    def __init__(self, ignore_last_label=False, **kwargs):
        """ignore_last_label:定義要不要忽略最后一個標簽,起到mask的效果
        """
        self.ignore_last_label = 1 if ignore_last_label else 0
        super(CRF, self).__init__(**kwargs)
    def build(self, input_shape):
        self.num_labels = input_shape[-1] - self.ignore_last_label
        self.trans = self.add_weight(name='crf_trans',
                                     shape=(self.num_labels, self.num_labels),
                                     initializer='glorot_uniform',
                                     trainable=True)
    def log_norm_step(self, inputs, states):
        """遞歸計算歸一化因子
        要點:1、遞歸計算;2、用logsumexp避免溢出。
        技巧:通過expand_dims來對齊張量。
        """
        states = K.expand_dims(states[0], 2) # previous
        inputs = K.expand_dims(inputs, 2) # 這個時刻的對標簽的打分值,Emission score
        trans = K.expand_dims(self.trans, 0) # 轉移矩陣

        output = K.logsumexp(states+trans+inputs, 1) # e 指數求和,log是防止溢出
        return output, [output] 

    def path_score(self, inputs, labels):
        """計算目標路徑的相對概率(還沒有歸一化)
        要點:逐標簽得分,加上轉移概率得分。
        技巧:用“預測”點乘“目標”的方法抽取出目標路徑的得分。
        """
        # 在CRF中涉及到標簽得分加上轉移概率,而這個point score就是相當於是標簽得分(在真是標簽的情況下,查看預測對於真實標簽位置的總得分),因為labels的shape是[B, T, N],而在N這個維度是one-hot,
        # 這里再乘以pred,相當於是對labels存在1的地方進行打分,其余地方全為0,再進行第2個維度相加表示去除0的值,再相加表示求一個總的標簽得分
        point_score = K.sum(K.sum(inputs*labels, 2), 1, keepdims=True) # 逐標簽得分, shape [B, 1]
        labels1 = K.expand_dims(labels[:, :-1], 3) # shape [B, T-1, N, 1]
        labels2 = K.expand_dims(labels[:, 1:], 2) # shape [B, T-1, 1, N]
        # 這里相乘的目的相當於從上一時刻轉移到當前時刻,確定當前時刻是從上一時刻哪一個標簽轉移過來的,因為labels是one-hot的形式,所以在最后兩個維度只有1個元素為1,其他全部為0,表示轉移標志
        labels = labels1 * labels2 # 兩個錯位labels,負責從轉移矩陣中抽取目標轉移得分 shape [B, T-1, N, N]
        trans = K.expand_dims(K.expand_dims(self.trans, 0), 0)
        # K.sum(trans*labels, [2, 3]),因為trans*labels的結果是[B, T-1, N, N], 而后面兩個維度中只有1個有值,表示轉移得分
        trans_score = K.sum(K.sum(trans*labels, [2, 3]), 1, keepdims=True) # 求出所有T-1時刻的概率轉移總得分,K.sum(trans*labels, [2, 3]), 表示每個時刻的轉移得分
        return point_score+trans_score # 兩部分得分之和

    def call(self, inputs): # CRF本身不改變輸出,它只是一個loss
        return inputs

    def loss(self, y_true, y_pred): # 目標y_pred需要是one hot形式
        mask = 1-y_true[:, 1:, -1] if self.ignore_last_label else None
        y_true, y_pred = y_true[:, :, :self.num_labels], y_pred[:, :, :self.num_labels]
        init_states = [y_pred[:, 0]] # 初始狀態
        log_norm, _, _ = K.rnn(self.log_norm_step, y_pred[:, 1:], init_states, mask=mask) # 計算Z向量(對數) shape[batch_size, output_dim]
        log_norm = K.logsumexp(log_norm, 1, keepdims=True) # 計算Z(對數)shape [batch_size, 1] 計算一個總的
        path_score = self.path_score(y_pred, y_true) # 計算分子(對數)
        return log_norm - path_score # 即log(分子/分母)

    def accuracy(self, y_true, y_pred): # 訓練過程中顯示逐幀准確率的函數,排除了mask的影響
        mask = 1-y_true[:,:,-1] if self.ignore_last_label else None
        y_true,y_pred = y_true[:,:,:self.num_labels],y_pred[:,:,:self.num_labels]
        isequal = K.equal(K.argmax(y_true, 2), K.argmax(y_pred, 2))
        isequal = K.cast(isequal, 'float32')
        if mask == None:
            return K.mean(isequal)
        else:
            return K.sum(isequal*mask) / K.sum(mask)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM