GRU-CTC中文語音識別



該項目github地址

基於keras的中文語音識別

  • 該項目實現了GRU-CTC中文語音識別,所有代碼都在gru_ctc_am.py中,包括:
    • 音頻文件特征提取
    • 文本數據處理
    • 數據格式處理
    • 構建模型
    • 模型訓練及解碼
  • 之外還包括將aishell數據處理為thchs30數據格式,合並數據進行訓練。代碼及數據放在gen_aishell_data中。

默認數據集為thchs30,參考gen_aishell_data中的數據及代碼,也可以使用aishell的數據進行訓練。

音頻文件特征提取

# -----------------------------------------------------------------------------------------------------
'''
&usage:		[audio]對音頻文件進行處理,包括生成總的文件列表、特征提取等
'''
# -----------------------------------------------------------------------------------------------------
# 生成音頻列表
def genwavlist(wavpath):
	wavfiles = {}
	fileids = []
	for (dirpath, dirnames, filenames) in os.walk(wavpath):
		for filename in filenames:
			if filename.endswith('.wav'):
				filepath = os.sep.join([dirpath, filename])
				fileid = filename.strip('.wav')
				wavfiles[fileid] = filepath
				fileids.append(fileid)
	return wavfiles,fileids

# 對音頻文件提取mfcc特征
def compute_mfcc(file):
	fs, audio = wav.read(file)
	mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
	mfcc_feat = mfcc_feat[::3]
	mfcc_feat = np.transpose(mfcc_feat)  
	mfcc_feat = pad_sequences(mfcc_feat, maxlen=500, dtype='float', padding='post', truncating='post').T
	return mfcc_feat

文本數據處理

# -----------------------------------------------------------------------------------------------------
'''
&usage:		[text]對文本標注文件進行處理,包括生成拼音到數字的映射,以及將拼音標注轉化為數字的標注轉化
'''
# -----------------------------------------------------------------------------------------------------
# 利用訓練數據生成詞典
def gendict(textfile_path):
	dicts = []
	textfile = open(textfile_path,'r+')
	for content in textfile.readlines():
		content = content.strip('\n')
		content = content.split(' ',1)[1]
		content = content.split(' ')
		dicts += (word for word in content)
	counter = Counter(dicts)
	words = sorted(counter)
	wordsize = len(words)
	word2num = dict(zip(words, range(wordsize)))
	num2word = dict(zip(range(wordsize), words))
	return word2num, num2word #1176個音素

# 文本轉化為數字
def text2num(textfile_path):
	lexcion,num2word = gendict(textfile_path)
	word2num = lambda word:lexcion.get(word, 0)
	textfile = open(textfile_path, 'r+')
	content_dict = {}
	for content in textfile.readlines():
		content = content.strip('\n')
		cont_id = content.split(' ',1)[0]
		content = content.split(' ',1)[1]
		content = content.split(' ')
		content = list(map(word2num,content))
		add_num = list(np.zeros(50-len(content)))
		content = content + add_num
		content_dict[cont_id] = content
	return content_dict,lexcion

數據格式處理

# -----------------------------------------------------------------------------------------------------
'''
&usage:		[data]數據生成器構造,用於訓練的數據生成,包括輸入特征及標注的生成,以及將數據轉化為特定格式
'''
# -----------------------------------------------------------------------------------------------------
# 將數據格式整理為能夠被網絡所接受的格式,被data_generator調用
def get_batch(x, y, train=False, max_pred_len=50, input_length=500):
    X = np.expand_dims(x, axis=3)
    X = x # for model2
#     labels = np.ones((y.shape[0], max_pred_len)) *  -1 # 3 # , dtype=np.uint8
    labels = y
    input_length = np.ones([x.shape[0], 1]) * ( input_length - 2 )
#     label_length = np.ones([y.shape[0], 1])
    label_length = np.sum(labels > 0, axis=1)
    label_length = np.expand_dims(label_length,1)
    inputs = {'the_input': X,
              'the_labels': labels,
              'input_length': input_length,
              'label_length': label_length,
              }
    outputs = {'ctc': np.zeros([x.shape[0]])}  # dummy data for dummy loss function
    return (inputs, outputs)

# 數據生成器,默認音頻為thchs30\train,默認標注為thchs30\train.syllable,被模型訓練方法fit_generator調用
def data_generate(wavpath = 'E:\\Data\\data_thchs30\\train', textfile = 'E:\\Data\\thchs30\\train.syllable.txt', bath_size=4):
	wavdict,fileids = genwavlist(wavpath)
	#print(wavdict)
	content_dict,lexcion = text2num(textfile)
	genloop = len(fileids)//bath_size
	print("all loop :", genloop)
	while True:
		feats = []
		labels = []
		# 隨機選擇某個音頻文件作為訓練數據
		i = random.randint(0,genloop-1)
		for x in range(bath_size):
			num = i * bath_size + x
			fileid = fileids[num]
			# 提取音頻文件的特征
			mfcc_feat = compute_mfcc(wavdict[fileid])
			feats.append(mfcc_feat)
			# 提取標注對應的label值
			labels.append(content_dict[fileid])
		# 將數據格式修改為get_batch可以處理的格式
		feats = np.array(feats)
		labels = np.array(labels)
		# 調用get_batch將數據處理為訓練所需的格式
		inputs, outputs = get_batch(feats, labels)
		yield inputs, outputs

構建模型

# -----------------------------------------------------------------------------------------------------
'''
&usage:		[net model]構件網絡結構,用於最終的訓練和識別
'''
# -----------------------------------------------------------------------------------------------------
# 被creatModel調用,用作ctc損失的計算
def ctc_lambda(args):
	labels, y_pred, input_length, label_length = args
	y_pred = y_pred[:, :, :]
	return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

# 構建網絡結構,用於模型的訓練和識別
def creatModel():
	input_data = Input(name='the_input', shape=(500, 26))
	layer_h1 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(input_data)
	#layer_h1 = Dropout(0.3)(layer_h1)
	layer_h2 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h1)
	layer_h3_1 = GRU(512, return_sequences=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
	layer_h3_2 = GRU(512, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
	layer_h3 = add([layer_h3_1, layer_h3_2])
	layer_h4 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h3)
	#layer_h4 = Dropout(0.3)(layer_h4)
	layer_h5 = Dense(1177, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h4)
	output = Activation('softmax', name='Activation0')(layer_h5)
	model_data = Model(inputs=input_data, outputs=output)
	#ctc
	labels = Input(name='the_labels', shape=[50], dtype='float32')
	input_length = Input(name='input_length', shape=[1], dtype='int64')
	label_length = Input(name='label_length', shape=[1], dtype='int64')
	loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, output, input_length, label_length])
	model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
	model.summary()
	ada_d = Adadelta(lr=0.01, rho=0.95, epsilon=1e-06)
	#model=multi_gpu_model(model,gpus=2)
	model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=ada_d)
	#test_func = K.function([input_data], [output])
	print("model compiled successful!")
	return model, model_data

模型訓練及解碼

# -----------------------------------------------------------------------------------------------------
'''
&usage:		模型的解碼,用於將數字信息映射為拼音
'''
# -----------------------------------------------------------------------------------------------------
# 對model預測出的softmax的矩陣,使用ctc的准則解碼,然后通過字典num2word轉為文字
def decode_ctc(num_result, num2word):
	result = num_result[:, :, :]
	in_len = np.zeros((1), dtype = np.int32)
	in_len[0] = 50;
	r = K.ctc_decode(result, in_len, greedy = True, beam_width=1, top_paths=1)
	r1 = K.get_value(r[0][0])
	r1 = r1[0]
	text = []
	for i in r1:
		text.append(num2word[i])
	return r1, text


# -----------------------------------------------------------------------------------------------------
'''
&usage:		模型的訓練
'''
# -----------------------------------------------------------------------------------------------------
# 訓練模型
def train():
	# 准備訓練所需數據
	yielddatas = data_generate()
	# 導入模型結構,訓練模型,保存模型參數
	model, model_data = creatModel()
	model.fit_generator(yielddatas, steps_per_epoch=2000, epochs=1)
	model.save_weights('model.mdl')
	model_data.save_weights('model_data.mdl')


# -----------------------------------------------------------------------------------------------------
'''
&usage:		模型的測試,看識別結果是否正確
'''
# -----------------------------------------------------------------------------------------------------
# 測試模型
def test():
	# 准備測試數據,以及生成字典
	word2num, num2word = gendict('E:\\Data\\thchs30\\train.syllable.txt')
	yielddatas = data_generate(bath_size=1)
	# 載入訓練好的模型,並進行識別
	model, model_data = creatModel()
	model_data.load_weights('model_data.mdl')
	result = model_data.predict_generator(yielddatas, steps=1)
	# 將數字結果轉化為文本結果
	result, text = decode_ctc(result, num2word)
	print('數字結果: ', result)
	print('文本結果:', text)

aishell數據轉化

將aishell中的漢字標注轉化為拼音標注,利用該數據與thchs30數據訓練同樣的網絡結構。

該模型作為一個練手小項目。
沒有使用語言模型,直接簡單建模。

我的github: https://github.com/audier


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM