bert-as-service輸出分類結果


bert-as-service: Mapping a variable-length sentence to a fixed-length vector using BERT model

默認情況下bert-as-service只提供固定長度的特征向量,如果想要直接獲取分類預測結果呢?

bert提供了的run_classifier.py 以訓練分類模型,同時bert提供了離線評估的方法。

一些可能的部署思路

  • bert基於tensorflow實現,可以參考tensorflow-serving對外提供部署服務
  • 參考bert代碼修改離線接口為在線推斷,基於flask/django提供部署服務
  • 修改bert-as-service提供高效在線預測服務

bert-as-service的強大可以參考:Serving Google BERT in Production using Tensorflow and ZeroMQ

修改bert-as-service提供分類預測

思路:https://github.com/hanxiao/bert-as-service/issues/213

bert-as-service 默認情況下,不會加載分類層

  1. 加載模型的同時加載分類層的權重和bias
  2. 添加分類層

graph.py#L79中添加

            if args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
                 hidden_size = 768
                 output_weights = tf.get_variable(
                     "output_weights", [args.num_labels, hidden_size],
                     )

                  output_bias = tf.get_variable(
                     "output_bias", [args.num_labels])

              tvars = tf.trainable_variables()		            

注意:在加載權重和bias的時候不要定義初始化方法,否則會從初始化方法進行加載,而不是微調模型。

graph.py#L127添加

                elif args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
                     # pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1)
                     logits = tf.matmul(pooled, output_weights, transpose_b=True)
                     logits = tf.nn.bias_add(logits, output_bias)
                     pooled = tf.nn.softmax(logits, axis=-1)

具體代碼github


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM