Keras/Tensorflow訓練邏輯研究


Keras是什么,以及相關的基礎知識,這里就不做詳細介紹,請參考Keras學習站點http://keras-cn.readthedocs.io/en/latest/

 

Tensorflow作為backend時的訓練邏輯梳理,主要是結合項目,研究了下源代碼!

 

我們的項目是智能問答機器人,基於雙向RNN(准確的說是GRU)網絡,這里網絡結構,就不做介紹,只研究其中的訓練邏輯,我們的訓練是基於fit_generator,即基於生成器模型,節省內存,有助效率提升。

什么是生成器以及生成器的工作原理,這里不表,屬於python的基礎范疇。

 

1. Keras的訓練,是基於batch進行的,每一個batch訓練過程,進行一次loss和acc的調整

1.1 .主要核心代碼

A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

1)里面的裝飾器函數generate_legacy_interface里面。這里涉及到fit_generator這個最為核心的入口函數的執行過程。

2)python里面裝飾器工作原理,非常類似java代碼里面的AOP切面編程邏輯,即在正常的業務邏輯執行前,將before或者after或者兩者都執行一下。

3)訓練函數原型及重要參數解釋

def fit_generator(self, generator, #生成器,一個yield的函數,迭代返回數據 steps_per_epoch, #一次訓練周期(具體epoch是什么含義,要理解清楚)里面進行多少次batch epochs=1, #設置進行幾次全數據集的訓練,每一次全數據集訓練過程被定義成一個epoch,其實這個是可以靈活應用的 verbose=1, #一個開關,打開時,打印清晰的訓練數據,即加載ProgbarLogger這個回調函數 callbacks=None, #設置業務需要的回調函數,我們的模型中添加了ModelCheckpoint這個回調函數 validation_data=None, #驗證用的數據源設置,evaluate_generator函數要用到這個數據源,我們的項目里面,這里也是一個生成器 validation_steps=None, #設置驗證多少次數據后取平均值作為此epoch訓練后的效果,val_loss,val_acc的值受這個參數直接影響 class_weight=None, #此參數以及后續參數,我們的項目采用的都是默認值,可以參考官方文檔了解細節 max_queue_size=10, workers=1, use_multiprocessing=False, initial_epoch=0)

 

B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

1)這里重點有ModelCheckpoint這個回調函數,涉及到業務參數,其他回調都是keras框架默認行為。

2)callback這個類,其實是一個容器,具體表現為一個List,可以在git_generator運行時,基於該函數的入參,構建一個Callback的實例,即一個list里面裝入業務需要的callback實例,這里默認會有BaseLogger以及History這個callback,然后會判斷verbose為true時,會添加ProgbarLogger這個callback,除此之外,就是fit_generator函數入參callbacks傳入的參數。一般都會傳遞ModelCheckpoint這個。

3)在git_generator這個基於生成器模式訓練的過程中,每一個epoch結束(on_epoch_end)時,都要調用這個callback函數(ModelCheckpoint)進行模型數據寫文件的操作

 

2. Keras訓練時用到的幾個重要回調函數(主要工作在on_batch_end里面)

回調函數是基於抽象類Callback實現的。下面是Callback的成員函數,便於理解。

   def __init__(self):
        self.validation_data = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

 

A. keras.callbacks.BaseLogger

統計該batch里面訓練的loss以及acc的值,計入totals,乘以batch_size后。

def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

在BaseLogger這個類的on_epoch_end函數里,執行對這個epoch訓練數據的loss以及acc求平均值。

def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    logs[k] = self.totals[k] / self.seen

 

B. keras.callbacks.ModelCheckpoint

在on_epoch_end時會保存模型數據進入文件

def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('Epoch %05d: %s did not improve' %
                                  (epoch, self.monitor))
            else:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' % (epoch, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)

 

C.keras.callbacks.History

主要記錄每一次epoch訓練的結果,結果包含loss以及acc的值

 

D. keras.callbacks.ProgbarLogger

這個函數里面實現訓練中間狀態數據信息的輸出,主要涉及進度相關信息。

 

3. 具體訓練邏輯過程

A. 訓練函數分析

a. model.fit_generator 訓練入口函數(參考上面的函數原型定義), 我們項目中用tk_data_generator函數作為訓練數據提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3)         callbacks.on_epoch_begin(epoch)
4)         while steps_done < steps_per_epoch:
5)             generator_output = next(output_generator)       #生成器next函數取輸入數據進行訓練,每次取一個batch大小的量
6)             callbacks.on_batch_begin(batch_index, batch_logs)
7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8)             callbacks.on_batch_end(batch_index, batch_logs)
            end of while steps_done < steps_per_epoch
            self.evaluate_generator(...)          #當一個epoch的最后一次batch執行完畢,執行一次訓練效果的評估
9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在這個執行過程中實現模型數據的保存操作
      end of while epoch < epochs
10) callbacks.on_train_end()


b. 特別介紹下train_on_batch
   train_on_batch (keras中的trainning.py)
        |_self._standardize_user_data
        |_self._make_train_function
        |_self.train_function (tensorflow的函數)
                        |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

 

B訓練和驗證的對比

a. 在每一個epoch的最后一個迭代(最后一次batch)時,要進行此輪epoch的校驗(evaluate)

日志如下:

141/141 [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450


第一個141表示batch_index已經達到141,即steps_per_epoch參數規定的最后一步
第二個141表示steps_per_epoch,即一個epoch里面進行多少次batch處理
12228s 表示此batch處理結束所花費的時間
loss:此epoch里面的平均損失值
acc:此epoch里面的平均准確率   
val_loss:此epoch訓練完后進行的evaluate得到的損失值
val_acc:此epoch訓練完后進行的evaluate得到的正確率

 

b. 驗證邏輯,和訓練邏輯差不多,只是將validation_steps指定次數的test的值進行取平均值,得到validation_steps次test的均值作為本epoch訓練的最終效果

self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

1) while steps_done < steps:
2)           generator_output = next(output_generator)
3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)對上述while得到的每次outs進行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

其中重點test_on_batch

test_on_batch(self, x, y, sample_weight=None)
         |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
         |_self._make_test_function()
         |_self.test_function(ins)                    
                    |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

 

c. train和test的重要區別,應該體現在下面的兩個函數上

def _make_train_function(self):
        if not hasattr(self, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.train_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]

            with K.name_scope('training'):
                with K.name_scope(self.optimizer.__class__.__name__):
                    training_updates = self.optimizer.get_updates(
                        params=self._collected_trainable_weights,
                        loss=self.total_loss)
                updates = self.updates + training_updates
                # Gets loss and metrics. Updates weights at each call.
                self.train_function = K.function(inputs,
                                                 [self.total_loss] + self.metrics_tensors,
                                                 updates=updates,
                                                 name='train_function',
                                                 **self._function_kwargs)
def _make_test_function(self):
        if not hasattr(self, 'test_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.test_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]
            # Return loss and metrics, no gradient updates. # Does update the network states.
            self.test_function = K.function(inputs,
                                            [self.total_loss] + self.metrics_tensors,
                                            updates=self.state_updates,
                                            name='test_function',
                                            **self._function_kwargs)

經過前面的代碼邏輯梳理,可以看到不管是train的過程還是test的過程,最終底層都是調用Tensorflow的session.run方法進行loss和acc的獲取,細心的觀察,會發現兩個session.run函數的入參其實有點不同。

結合上面train和test的私有函數中標注紅色的注釋,以及用K.function生成函數的入參中,可以看出train和test的差異。

 

總結:

0. 訓練過程中,每次權重的更新都是在一個batch上進行一次,是基於batch量的數據為單位進行一次權重的更新

1. 基於生成器模型訓練數據,可以提升效率,降低對物理服務器性能,尤其是內存的要求

2. 訓練過程中,Callback函數執行了大量的工作,包括loss、acc值的記錄,以及訓練中間結果的日志反饋,最重要的是模型數據的輸出,也是通過callback的方式實現(ModelCheckpoint)

3. 訓練(train)和驗證(evaluate/validate)的邏輯近乎一樣,訓練要更新權重,但是驗證過程,僅僅更新網絡狀態,不涉及權重(loss以及acc參數)信息的更新

4. 代碼梳理過程中,得出結論,Keras對python編程基本功底要求還是有點高的,采用了推導式編程習慣,生成器,裝飾器,回調等編程思想,另外,對矩陣運算,例如numpy.dot以及numpy.multiply的數學邏輯都有一定要求,否則比較難看懂。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM