摘要:WeNet是一款開源端到端ASR工具包,它與ESPnet等開源語音項目相比,最大的優勢在於提供了從訓練到部署的一整套工具鏈,使ASR服務的工業落地更加簡單。
本文分享自華為雲社區《WeNet雲端推理部署代碼解析》,作者:xiaoye0829 。
WeNet是一款開源端到端ASR工具包,它與ESPnet等開源語音項目相比,最大的優勢在於提供了從訓練到部署的一整套工具鏈,使ASR服務的工業落地更加簡單。如圖1所示,WeNet工具包完全依賴於PyTorch生態:使用TorchScript進行模型開發,使用Torchaudio進行動態特征提取,使用DistributedDataParallel進行分布式訓練,使用torch JIT(Just In Time)進行模型導出,使用LibTorch作為生產環境運行時。本系列將對WeNet雲端推理部署代碼進行解析。
圖1:WeNet系統設計[1]
1. 代碼結構
WeNet雲端推理和部署代碼位於wenet/runtime/server/x86路徑下,編程語言為C++,其結構如下所示:
其中:
- 語音文件讀入與特征提取相關代碼位於frontend文件夾下;
- 端到端模型導入、端點檢測與語音解碼識別相關代碼位於decoder文件夾下,WeNet支持CTC prefix beam search和融合了WFST的CTC beam search這兩種解碼算法,后者的實現大量借鑒了Kaldi,相關代碼放在kaldi文件夾下;
- 在服務化方面,WeNet分別實現了基於WebSocket和基於gRPC的兩套服務端與客戶端,基於WebSocket的實現位於websocket文件夾下,基於gRPC的實現位於grpc文件夾下,兩種實現的入口main函數代碼都位於bin文件夾下。
- 日志、計時、字符串處理等輔助代碼位於utils文件夾下。
WeNet提供了CMakeLists.txt和Dockerfile,使得用戶能方便地進行項目編譯和鏡像構建。
2. 前端:frontend文件夾
1)語音文件讀入
WeNet只支持44字節header的wav格式音頻數據,wav header定義在WavHeader結構體中,包括音頻格式、聲道數、采樣率等音頻元信息。WavReader類用於語音文件讀入,調用fopen打開語音文件后,WavReader先讀入WavHeader大小的數據(也就是44字節),再根據WavHeader中的元信息確定待讀入音頻數據的大小,最后調用fread把音頻數據讀入buffer,並通過static_cast把數據轉化為float類型。
struct WavHeader { char riff[4]; // "riff" unsigned int size; char wav[4]; // "WAVE" char fmt[4]; // "fmt " unsigned int fmt_size; uint16_t format; uint16_t channels; unsigned int sample_rate; unsigned int bytes_per_second; uint16_t block_size; uint16_t bit; char data[4]; // "data" unsigned int data_size; };
這里存在的一個風險是,如果WavHeader中存放的元信息有誤,則會影響到語音數據的正確讀入。
2)特征提取
WeNet使用的特征是fbank,通過FeaturePipelineConfig結構體進行特征設置。默認幀長為25ms,幀移為10ms,采樣率和fbank維數則由用戶輸入。
用於特征提取的類是FeaturePipeline。為了同時支持流式與非流式語音識別,FeaturePipeline類中設置了input_finished_屬性來標志輸入是否結束,並通過set_input_finished()成員函數來對input_finished_屬性進行操作。
提取出來的fbank特征放在feature_queue_中,feature_queue_的類型是BlockingQueue<std::vector<float>>。BlockingQueue類是WeNet實現的一個阻塞隊列,初始化的時候需要提供隊列的容量(capacity),通過Push()函數向隊列中增加特征,通過Pop()函數從隊列中讀取特征:
- 當feature_queue_中的feature數量超過capacity,則Push線程被掛起,等待feature_queue_.Pop()釋放出空間。
- 當feature_queue_為空,則Pop線程被掛起,等待feature_queue_.Push()。
線程的掛起和恢復是通過C++標准庫中的線程同步原語std::mutex、std::condition_variable等實現。
線程同步還用在AcceptWaveform和ReadOne兩個成員函數中,AcceptWaveform把語音數據提取得到的fbank特征放到feature_queue_中,ReadOne成員函數則把特征從feature_queue_中讀出,是經典的生產者消費者模式。
3. 解碼器:decoder文件夾
1)TorchAsrModel
通過torch::jit::load對存在磁盤上的模型進行反序列化,得到一個ScriptModule對象。
torch::jit::script::Module model = torch::jit::load(model_path);
2)SearchInterface
WeNet推理支持的解碼方式都繼承自基類SearchInterface,如果要新增解碼算法,則需繼承SearchInterface類,並提供該類中所有純虛函數的實現,包括:
// 解碼算法的具體實現 virtual void Search(const torch::Tensor& logp) = 0; // 重置解碼過程 virtual void Reset() = 0; // 結束解碼過程 virtual void FinalizeSearch() = 0; // 解碼算法類型,返回一個枚舉常量SearchType virtual SearchType Type() const = 0; // 返回解碼輸入 virtual const std::vector<std::vector<int>>& Inputs() const = 0; // 返回解碼輸出 virtual const std::vector<std::vector<int>>& Outputs() const = 0; // 返回解碼輸出對應的似然值 virtual const std::vector<float>& Likelihood() const = 0; // 返回解碼輸出對應的次數 virtual const std::vector<std::vector<int>>& Times() const = 0;
目前WeNet只提供了SearchInterface的兩種子類實現,也即兩種解碼算法,分別定義在CtcPrefixBeamSearch和CtcWfstBeamSearch兩個類中。
3)CtcEndpoint
WeNet支持語音端點檢測,提供了一種基於規則的實現方式,用戶可以通過CtcEndpointConfig結構體和CtcEndpointRule結構體進行規則配置。WeNet默認的規則有三條:
- 檢測到了5s的靜音,則認為檢測到端點;
- 解碼出了任意時長的語音后,檢測到了1s的靜音,則認為檢測到端點;
- 解碼出了20s的語音,則認為檢測到端點。
一旦檢測到端點,則結束解碼。另外,WeNet把解碼得到的空白符(blank)視作靜音。
4)TorchAsrDecoder
WeNet提供的解碼器定義在TorchAsrDecoder類中。如圖3所示,WeNet支持雙向解碼,即疊加從左往右解碼和從右往左解碼的結果。在CTC beam search之后,用戶還可以選擇進行attention重打分。
圖2:WeNet解碼計算流程[2]
可以通過DecodeOptions結構體進行解碼參數配置,包括如下參數:
struct DecodeOptions { int chunk_size = 16; int num_left_chunks = -1; float ctc_weight = 0.0; float rescoring_weight = 1.0; float reverse_weight = 0.0; CtcEndpointConfig ctc_endpoint_config; CtcPrefixBeamSearchOptions ctc_prefix_search_opts; CtcWfstBeamSearchOptions ctc_wfst_search_opts; };
其中,ctc_weight表示CTC解碼權重,rescoring_weight表示重打分權重,reverse_weight表示從右往左解碼權重。最終解碼打分的計算方式為:
final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score; rescoring_score = left_to_right_score * (1 - reverse_weight) + right_to_left_score * reverse_weight
TorchAsrDecoder對外提供的解碼接口是Decode(),重打分接口是Rescoring()。Decode()返回的是枚舉類型DecodeState,包括三個枚舉常量:kEndBatch,kEndpoint和kEndFeats,分別表示當前批數據解碼結束、檢測到端點、所有特征解碼結束。
為了支持長語音識別,WeNet還提供了連續解碼接口ResetContinuousDecoding(),它與解碼器重置接口Reset()的區別在於:連續解碼接口會記錄全局已經解碼的語音幀數,並保留當前feature_pipeline_的狀態。
由於流式ASR服務需要在客戶端和服務端之間進行雙向的流式數據傳輸,WeNet實現了兩種支持雙向流式通信的服務化接口,分別基於WebSocket和gRPC。
4. 基於WebSocket
1)WebSocket簡介
WebSocket是基於TCP的一種新的網絡協議,與HTTP協議不同,WebSocket允許服務器主動發送信息給客戶端。 在連接建立后,客戶端和服務端可以連續互相發送數據,而無需在每次發送數據時重新發起連接請求。因此大大減小了網絡帶寬的資源消耗 ,在性能上更有優勢。
WebSocket支持文本和二進制兩種格式的數據傳輸 。
2)WeNet的WebSocket接口
WeNet使用了boost庫的WebSocket實現,定義了WebSocketClient(客戶端)和WebSocketServer(服務端)兩個類。
在流式ASR過程中,WebSocketClient給WebSocketServer發送數據可以分為三個步驟:1)發送開始信號與解碼配置;2)發送二進制語音數據:pcm字節流;3)發送停止信號。從WebSocketClient::SendStartSignal()和WebSocketClient::SendEndSignal()可以看到,開始信號、解碼配置和停止信號都是包裝在json字符串中,通過WebSocket文本格式傳輸。pcm字節流則通過WebSocket二進制格式進行傳輸。
void WebSocketClient::SendStartSignal() { // TODO(Binbin Zhang): Add sample rate and other setting surpport json::value start_tag = {{"signal", "start"}, {"nbest", nbest_}, {"continuous_decoding", continuous_decoding_}}; std::string start_message = json::serialize(start_tag); this->SendTextData(start_message); } void WebSocketClient::SendEndSignal() { json::value end_tag = {{"signal", "end"}}; std::string end_message = json::serialize(end_tag); this->SendTextData(end_message); }
WebSocketServer在收到數據后,需要先判斷收到的數據是文本還是二進制格式:如果是文本數據,則進行json解析,並根據解析結果進行解碼配置、啟動或停止,處理邏輯定義在ConnectionHandler::OnText()函數中。如果是二進制數據,則進行語音識別,處理邏輯定義在ConnectionHandler::OnSpeechData()中。
3)缺點
WebSocket需要開發者在WebSocketClient和WebSocketServer寫好對應的消息構造和解析代碼,容易出錯。另外,從以上代碼來看,服務需要借助json格式來序列化和反序列化數據,效率沒有protobuf格式高。
對於這些缺點,gRPC框架提供了更好的解決方法。
5. 基於gRPC
1)gRPC簡介
gRPC是谷歌推出的開源RPC框架,使用HTTP2作為網絡傳輸協議,並使用protobuf作為數據交換格式,有更高的數據傳輸效率。在gRPC框架下,開發者只需通過一個.proto文件定義好RPC服務(service)與消息(message),便可通過gRPC提供的代碼生成工具(protoc compiler)自動生成消息構造和解析代碼,使開發者能更好地聚焦於接口設計本身。
進行RPC調用時,gRPC Stub(客戶端)向gRPC Server(服務端)發送.proto文件中定義的Request消息,gRPC Server在處理完請求之后,通過.proto文件中定義的Response消息將結果返回給gRPC Stub。
gRPC具有跨語言特性,支持不同語言寫的微服務進行互動,比如說服務端用C++實現,客戶端用Ruby實現。protoc compiler支持12種語言的代碼生成。
圖1:gRPC Server和gRPC Stub交互[1]
2)WeNet的proto文件
WeNet定義的服務為ASR,包含一個Recognize方法,該方法的輸入(Request)、輸出(Response)都是流式數據(stream)。在使用protoc compiler編譯proto文件后,會得到4個文件:wenet.grpc.pb.h,http://wenet.grpc.pb.cc,wenet.pb.h,http://wenet.pb.cc。其中,wenet.pb.h/cc中存儲了protobuf數據格式的定義,wenet.grpc.pb.h中存儲了gRPC服務端/客戶端的定義。通過在代碼中包括wenet.pb.h和wenet.grpc.pb.h兩個頭文件,開發者可以直接使用Request消息和Response消息類,訪問其字段。
service ASR { rpc Recognize (stream Request) returns (stream Response) {} } message Request { message DecodeConfig { int32 nbest_config = 1; bool continuous_decoding_config = 2; } oneof RequestPayload { DecodeConfig decode_config = 1; bytes audio_data = 2; } } message Response { message OneBest { string sentence = 1; repeated OnePiece wordpieces = 2; } message OnePiece { string word = 1; int32 start = 2; int32 end = 3; } enum Status { ok = 0; failed = 1; } enum Type { server_ready = 0; partial_result = 1; final_result = 2; speech_end = 3; } Status status = 1; Type type = 2; repeated OneBest nbest = 3; }
3)WeNet的gRPC實現
WeNet gRPC服務端定義了GrpcServer類,該類繼承自wenet.grpc.pb.h中的純虛基類ASR::Service。
語音識別的入口函數是GrpcServer::Recognize,該函數初始化一個GRPCConnectionHandler實例來進行語音識別,並通過ServerReaderWriter類的stream對象來傳遞輸入輸出。
Status GrpcServer::Recognize(ServerContext* context, ServerReaderWriter<Response, Request>* stream) { LOG(INFO) << "Get Recognize request" << std::endl; auto request = std::make_shared<Request>(); auto response = std::make_shared<Response>(); GrpcConnectionHandler handler(stream, request, response, feature_config_, decode_config_, symbol_table_, model_, fst_); std::thread t(std::move(handler)); t.join(); return Status::OK; }
WeNet gRPC客戶端定義了GrpcClient類。客戶端在建立與服務端的連接時需實例化ASR::Stub,並通過ClientReaderWriter類的stream對象,實現雙向流式通信。
void GrpcClient::Connect() { channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_), grpc::InsecureChannelCredentials()); stub_ = ASR::NewStub(channel_); context_ = std::make_shared<ClientContext>(); stream_ = stub_->Recognize(context_.get()); request_ = std::make_shared<Request>(); response_ = std::make_shared<Response>(); request_->mutable_decode_config()->set_nbest_config(nbest_); request_->mutable_decode_config()->set_continuous_decoding_config( continuous_decoding_); stream_->Write(*request_); }
http://grpc_client_main.cc中,客戶端分段傳輸語音數據,每0.5s進行一次傳輸,即對於一個采樣率為8k的語音文件來說,每次傳4000幀數據。為了減小傳輸數據的大小,提升數據傳輸速度,先在客戶端將float類型轉為int16_t,服務端在接受到數據后,再將int16_t轉為float。c++中float為32位。
int main(int argc, char *argv[]) { ... // Send data every 0.5 second const float interval = 0.5; const int sample_interval = interval * sample_rate; for (int start = 0; start < num_sample; start += sample_interval) { if (client.done()) { break; } int end = std::min(start + sample_interval, num_sample); // Convert to short std::vector<int16_t> data; data.reserve(end - start); for (int j = start; j < end; j++) { data.push_back(static_cast<int16_t>(pcm_data[j])); } // Send PCM data client.SendBinaryData(data.data(), data.size() * sizeof(int16_t)); ... }
總結
本文主要對WeNet雲端部署代碼進行解析,介紹了WeNet基於WebSocket和基於gRPC的兩種服務化接口。
WeNet代碼結構清晰,簡潔易用,為語音識別提供了從訓練到部署的一套端到端解決方案,大大促進了工業落地效率,是非常值得借鑒學習的語音開源項目。
參考
[1] https://grpc.io/docs/what-is-grpc/introduction/
[2]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit
[3]WeNet源碼
[4]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit
[5] U2++: Unified Two-pass Bidirectional End-to-end Model for Speech Recognition