今天跑wav2vec的預訓練模型:
import torch from fairseq.models.wav2vec import Wav2VecModel import librosa cp = torch.load('../models/wav2vec_large.pt') model = Wav2VecModel.build_model(cp['args'], task=None) model.load_state_dict(cp['model']) signal, sr = librosa.load('../static/test.wav') tensors = torch.from_numpy(signal).unsqueeze(0) z = model.feature_extractor(tensors) c = model.feature_aggregator(z) #print('c:', c) print(c.shape)
但是遇到一個非常惡心的問題,截圖如下:
分明是按照網上的代碼一步一步來的,就是報錯,困擾的很長時間,最后發現是fairseq安裝的版本不對。最開始的時候安裝的版本是1.0,但是fairseq是一個更新非常快的庫,但是代碼中加載的模型已經提出來有一段時間了,所以會出現參數不匹配的問題,將fairseq版本改為0.9.0版本就可以運行出來了。
完整的代碼見github:https://github.com/SolbiatiAlessandro/wav2vec.git
如果變換版本之后還是不行的話,建議參考https://blog.csdn.net/starinline/article/details/109944198這篇博客,里邊的博主也遇到了相同的問題,但是他改變的是hydra/_internal/utils.py中的參數,細節請到該博客閱讀
本人在嘗試了上邊的方法之后,問題仍然沒有解決,所以我建議,大家如果也遇到了相同的問題,先嘗試一下上邊博主的方法,如果嘗試無果的話,再嘗試更換一個fairseq的版本。