解析WeNet雲端推理部署代碼


摘要:WeNet是一款開源端到端ASR工具包,它與ESPnet等開源語音項目相比,最大的優勢在於提供了從訓練到部署的一整套工具鏈,使ASR服務的工業落地更加簡單。

本文分享自華為雲社區《WeNet雲端推理部署代碼解析》,作者:xiaoye0829 。

WeNet是一款開源端到端ASR工具包,它與ESPnet等開源語音項目相比,最大的優勢在於提供了從訓練到部署的一整套工具鏈,使ASR服務的工業落地更加簡單。如圖1所示,WeNet工具包完全依賴於PyTorch生態:使用TorchScript進行模型開發,使用Torchaudio進行動態特征提取,使用DistributedDataParallel進行分布式訓練,使用torch JIT(Just In Time)進行模型導出,使用LibTorch作為生產環境運行時。本系列將對WeNet雲端推理部署代碼進行解析。

圖1:WeNet系統設計[1]

1. 代碼結構

WeNet雲端推理和部署代碼位於wenet/runtime/server/x86路徑下,編程語言為C++,其結構如下所示:

其中:

  • 語音文件讀入與特征提取相關代碼位於frontend文件夾下;
  • 端到端模型導入、端點檢測與語音解碼識別相關代碼位於decoder文件夾下,WeNet支持CTC prefix beam search和融合了WFST的CTC beam search這兩種解碼算法,后者的實現大量借鑒了Kaldi,相關代碼放在kaldi文件夾下;
  • 在服務化方面,WeNet分別實現了基於WebSocket和基於gRPC的兩套服務端與客戶端,基於WebSocket的實現位於websocket文件夾下,基於gRPC的實現位於grpc文件夾下,兩種實現的入口main函數代碼都位於bin文件夾下。
  • 日志、計時、字符串處理等輔助代碼位於utils文件夾下。

WeNet提供了CMakeLists.txt和Dockerfile,使得用戶能方便地進行項目編譯和鏡像構建。

2. 前端:frontend文件夾

1)語音文件讀入

WeNet只支持44字節header的wav格式音頻數據,wav header定義在WavHeader結構體中,包括音頻格式、聲道數、采樣率等音頻元信息。WavReader類用於語音文件讀入,調用fopen打開語音文件后,WavReader先讀入WavHeader大小的數據(也就是44字節),再根據WavHeader中的元信息確定待讀入音頻數據的大小,最后調用fread把音頻數據讀入buffer,並通過static_cast把數據轉化為float類型。

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];  // "WAVE"
  char fmt[4];  // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

這里存在的一個風險是,如果WavHeader中存放的元信息有誤,則會影響到語音數據的正確讀入。

2)特征提取

WeNet使用的特征是fbank,通過FeaturePipelineConfig結構體進行特征設置。默認幀長為25ms,幀移為10ms,采樣率和fbank維數則由用戶輸入。

用於特征提取的類是FeaturePipeline。為了同時支持流式與非流式語音識別,FeaturePipeline類中設置了input_finished_屬性來標志輸入是否結束,並通過set_input_finished()成員函數來對input_finished_屬性進行操作。

提取出來的fbank特征放在feature_queue_中,feature_queue_的類型是BlockingQueue<std::vector<float>>。BlockingQueue類是WeNet實現的一個阻塞隊列,初始化的時候需要提供隊列的容量(capacity),通過Push()函數向隊列中增加特征,通過Pop()函數從隊列中讀取特征:

  • 當feature_queue_中的feature數量超過capacity,則Push線程被掛起,等待feature_queue_.Pop()釋放出空間。
  • 當feature_queue_為空,則Pop線程被掛起,等待feature_queue_.Push()。
    線程的掛起和恢復是通過C++標准庫中的線程同步原語std::mutex、std::condition_variable等實現。
    線程同步還用在AcceptWaveform和ReadOne兩個成員函數中,AcceptWaveform把語音數據提取得到的fbank特征放到feature_queue_中,ReadOne成員函數則把特征從feature_queue_中讀出,是經典的生產者消費者模式。

3. 解碼器:decoder文件夾

1)TorchAsrModel

通過torch::jit::load對存在磁盤上的模型進行反序列化,得到一個ScriptModule對象。

torch::jit::script::Module model = torch::jit::load(model_path);

2)SearchInterface

WeNet推理支持的解碼方式都繼承自基類SearchInterface,如果要新增解碼算法,則需繼承SearchInterface類,並提供該類中所有純虛函數的實現,包括:

// 解碼算法的具體實現
virtual void Search(const torch::Tensor& logp) = 0;
// 重置解碼過程
virtual void Reset() = 0;
// 結束解碼過程
virtual void FinalizeSearch() = 0;
// 解碼算法類型,返回一個枚舉常量SearchType
virtual SearchType Type() const = 0;
// 返回解碼輸入
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
// 返回解碼輸出
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
// 返回解碼輸出對應的似然值
virtual const std::vector<float>& Likelihood() const = 0;
// 返回解碼輸出對應的次數
virtual const std::vector<std::vector<int>>& Times() const = 0;

目前WeNet只提供了SearchInterface的兩種子類實現,也即兩種解碼算法,分別定義在CtcPrefixBeamSearch和CtcWfstBeamSearch兩個類中。

3)CtcEndpoint

WeNet支持語音端點檢測,提供了一種基於規則的實現方式,用戶可以通過CtcEndpointConfig結構體和CtcEndpointRule結構體進行規則配置。WeNet默認的規則有三條:

  • 檢測到了5s的靜音,則認為檢測到端點;
  • 解碼出了任意時長的語音后,檢測到了1s的靜音,則認為檢測到端點;
  • 解碼出了20s的語音,則認為檢測到端點。
    一旦檢測到端點,則結束解碼。另外,WeNet把解碼得到的空白符(blank)視作靜音。

4)TorchAsrDecoder

WeNet提供的解碼器定義在TorchAsrDecoder類中。如圖3所示,WeNet支持雙向解碼,即疊加從左往右解碼和從右往左解碼的結果。在CTC beam search之后,用戶還可以選擇進行attention重打分。

圖2:WeNet解碼計算流程[2]

可以通過DecodeOptions結構體進行解碼參數配置,包括如下參數:

struct DecodeOptions {
  int chunk_size = 16;
  int num_left_chunks = -1;
  float ctc_weight = 0.0;
  float rescoring_weight = 1.0;
  float reverse_weight = 0.0;
  CtcEndpointConfig ctc_endpoint_config;
  CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
  CtcWfstBeamSearchOptions ctc_wfst_search_opts;
};

其中,ctc_weight表示CTC解碼權重,rescoring_weight表示重打分權重,reverse_weight表示從右往左解碼權重。最終解碼打分的計算方式為:

final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
rescoring_score = left_to_right_score * (1 - reverse_weight) +
right_to_left_score * reverse_weight

TorchAsrDecoder對外提供的解碼接口是Decode(),重打分接口是Rescoring()。Decode()返回的是枚舉類型DecodeState,包括三個枚舉常量:kEndBatch,kEndpoint和kEndFeats,分別表示當前批數據解碼結束、檢測到端點、所有特征解碼結束。

為了支持長語音識別,WeNet還提供了連續解碼接口ResetContinuousDecoding(),它與解碼器重置接口Reset()的區別在於:連續解碼接口會記錄全局已經解碼的語音幀數,並保留當前feature_pipeline_的狀態。

由於流式ASR服務需要在客戶端和服務端之間進行雙向的流式數據傳輸,WeNet實現了兩種支持雙向流式通信的服務化接口,分別基於WebSocket和gRPC。

4. 基於WebSocket

1)WebSocket簡介

WebSocket是基於TCP的一種新的網絡協議,與HTTP協議不同,WebSocket允許服務器主動發送信息給客戶端。 在連接建立后,客戶端和服務端可以連續互相發送數據,而無需在每次發送數據時重新發起連接請求。因此大大減小了網絡帶寬的資源消耗 ,在性能上更有優勢。

WebSocket支持文本和二進制兩種格式的數據傳輸 。

2)WeNet的WebSocket接口

WeNet使用了boost庫的WebSocket實現,定義了WebSocketClient(客戶端)和WebSocketServer(服務端)兩個類。

在流式ASR過程中,WebSocketClient給WebSocketServer發送數據可以分為三個步驟:1)發送開始信號與解碼配置;2)發送二進制語音數據:pcm字節流;3)發送停止信號。從WebSocketClient::SendStartSignal()和WebSocketClient::SendEndSignal()可以看到,開始信號、解碼配置和停止信號都是包裝在json字符串中,通過WebSocket文本格式傳輸。pcm字節流則通過WebSocket二進制格式進行傳輸。

void WebSocketClient::SendStartSignal() {
  // TODO(Binbin Zhang): Add sample rate and other setting surpport
  json::value start_tag = {{"signal", "start"},
                           {"nbest", nbest_},
                           {"continuous_decoding", continuous_decoding_}};
  std::string start_message = json::serialize(start_tag);
  this->SendTextData(start_message);
}

void WebSocketClient::SendEndSignal() {
  json::value end_tag = {{"signal", "end"}};
  std::string end_message = json::serialize(end_tag);
  this->SendTextData(end_message);
}

WebSocketServer在收到數據后,需要先判斷收到的數據是文本還是二進制格式:如果是文本數據,則進行json解析,並根據解析結果進行解碼配置、啟動或停止,處理邏輯定義在ConnectionHandler::OnText()函數中。如果是二進制數據,則進行語音識別,處理邏輯定義在ConnectionHandler::OnSpeechData()中。

3)缺點

WebSocket需要開發者在WebSocketClient和WebSocketServer寫好對應的消息構造和解析代碼,容易出錯。另外,從以上代碼來看,服務需要借助json格式來序列化和反序列化數據,效率沒有protobuf格式高。

對於這些缺點,gRPC框架提供了更好的解決方法。

5. 基於gRPC

1)gRPC簡介

gRPC是谷歌推出的開源RPC框架,使用HTTP2作為網絡傳輸協議,並使用protobuf作為數據交換格式,有更高的數據傳輸效率。在gRPC框架下,開發者只需通過一個.proto文件定義好RPC服務(service)與消息(message),便可通過gRPC提供的代碼生成工具(protoc compiler)自動生成消息構造和解析代碼,使開發者能更好地聚焦於接口設計本身。

進行RPC調用時,gRPC Stub(客戶端)向gRPC Server(服務端)發送.proto文件中定義的Request消息,gRPC Server在處理完請求之后,通過.proto文件中定義的Response消息將結果返回給gRPC Stub。

gRPC具有跨語言特性,支持不同語言寫的微服務進行互動,比如說服務端用C++實現,客戶端用Ruby實現。protoc compiler支持12種語言的代碼生成。

圖1:gRPC Server和gRPC Stub交互[1]

2)WeNet的proto文件

WeNet定義的服務為ASR,包含一個Recognize方法,該方法的輸入(Request)、輸出(Response)都是流式數據(stream)。在使用protoc compiler編譯proto文件后,會得到4個文件:wenet.grpc.pb.h,,wenet.pb.h,。其中,wenet.pb.h/cc中存儲了protobuf數據格式的定義,wenet.grpc.pb.h中存儲了gRPC服務端/客戶端的定義。通過在代碼中包括wenet.pb.h和wenet.grpc.pb.h兩個頭文件,開發者可以直接使用Request消息和Response消息類,訪問其字段。

service ASR {
  rpc Recognize (stream Request) returns (stream Response) {}
}

message Request {

  message DecodeConfig {
    int32 nbest_config = 1;
    bool continuous_decoding_config = 2;
  }

  oneof RequestPayload {
    DecodeConfig decode_config = 1;
    bytes audio_data = 2;
  }
}

message Response {

  message OneBest {
    string sentence = 1;
    repeated OnePiece wordpieces = 2;
  }

  message OnePiece {
    string word = 1;
    int32 start = 2;
    int32 end = 3;
  }

  enum Status {
    ok = 0;
    failed = 1;
  }

  enum Type {
    server_ready = 0;
    partial_result = 1;
    final_result = 2;
    speech_end = 3;
  }

  Status status = 1;
  Type type = 2;
  repeated OneBest nbest = 3;
}

3)WeNet的gRPC實現

WeNet gRPC服務端定義了GrpcServer類,該類繼承自wenet.grpc.pb.h中的純虛基類ASR::Service。

語音識別的入口函數是GrpcServer::Recognize,該函數初始化一個GRPCConnectionHandler實例來進行語音識別,並通過ServerReaderWriter類的stream對象來傳遞輸入輸出。

Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;
  auto request = std::make_shared<Request>();
  auto response = std::make_shared<Response>();
  GrpcConnectionHandler handler(stream, request, response, feature_config_,
                                decode_config_, symbol_table_, model_, fst_);
  std::thread t(std::move(handler));
  t.join();
  return Status::OK;
}

WeNet gRPC客戶端定義了GrpcClient類。客戶端在建立與服務端的連接時需實例化ASR::Stub,並通過ClientReaderWriter類的stream對象,實現雙向流式通信。

void GrpcClient::Connect() {
  channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
                                 grpc::InsecureChannelCredentials());
  stub_ = ASR::NewStub(channel_);
  context_ = std::make_shared<ClientContext>();
  stream_ = stub_->Recognize(context_.get());
  request_ = std::make_shared<Request>();
  response_ = std::make_shared<Response>();
  request_->mutable_decode_config()->set_nbest_config(nbest_);
  request_->mutable_decode_config()->set_continuous_decoding_config(
      continuous_decoding_);
  stream_->Write(*request_);
}

中,客戶端分段傳輸語音數據,每0.5s進行一次傳輸,即對於一個采樣率為8k的語音文件來說,每次傳4000幀數據。為了減小傳輸數據的大小,提升數據傳輸速度,先在客戶端將float類型轉為int16_t,服務端在接受到數據后,再將int16_t轉為float。c++中float為32位。

int main(int argc, char *argv[]) {
  ...
  // Send data every 0.5 second
  const float interval = 0.5;
  const int sample_interval = interval * sample_rate;
  for (int start = 0; start < num_sample; start += sample_interval) {
    if (client.done()) {
      break;
    }
    int end = std::min(start + sample_interval, num_sample);
    // Convert to short
    std::vector<int16_t> data;
    data.reserve(end - start);
    for (int j = start; j < end; j++) {
      data.push_back(static_cast<int16_t>(pcm_data[j]));
    }
    // Send PCM data
    client.SendBinaryData(data.data(), data.size() * sizeof(int16_t));
    ...
}

總結

本文主要對WeNet雲端部署代碼進行解析,介紹了WeNet基於WebSocket和基於gRPC的兩種服務化接口。

WeNet代碼結構清晰,簡潔易用,為語音識別提供了從訓練到部署的一套端到端解決方案,大大促進了工業落地效率,是非常值得借鑒學習的語音開源項目。

參考

[1] https://grpc.io/docs/what-is-grpc/introduction/

[2]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit

[3]WeNet源碼

[4]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit

[5] U2++: Unified Two-pass Bidirectional End-to-end Model for Speech Recognition

 

點擊關注,第一時間了解華為雲新鮮技術~


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM