整個過程分為以下步驟完成:
- 語料准備
- 語料預處理
- 模型參數配置
- 構建模型
- 訓練模型
- 模型作詩
- 繪制模型網絡結構圖
下面一步步來構建和訓練一個會寫詩的模型。
第一,語料准備。一共四萬多首古詩,每行一首詩,標題在預處理的時候已經去掉了。
第二,文件預處理。首先,機器並不懂每個中文漢字代表的是什么,所以要將文字轉換為機器能理解的形式,這里我們采用 One-Hot 的形式,這樣詩句中的每個字都能用向量來表示,下面定義函數 preprocess_file()
來處理。
1 puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》'] 2 3 4 def preprocess_file(Config): 5 # 語料文本內容 6 files_content = '' 7 with open(Config.poetry_file, 'r', encoding='utf-8') as f: 8 for line in f: 9 # 每行的末尾加上"]"符號代表一首詩結束 10 for char in puncs: 11 line = line.replace(char, "") 12 files_content += line.strip() + "]" 13 14 words = sorted(list(files_content)) 15 words.remove(']') 16 counted_words = {} 17 for word in words: 18 if word in counted_words: 19 counted_words[word] += 1 20 else: 21 counted_words[word] = 1 22 23 # 去掉低頻的字 24 erase = [] 25 for key in counted_words: 26 if counted_words[key] <= 2: 27 erase.append(key) 28 for key in erase: 29 del counted_words[key] 30 del counted_words[']'] 31 wordPairs = sorted(counted_words.items(), key=lambda x: -x[1]) 32 33 words, _ = zip(*wordPairs) 34 # word到id的映射 35 word2num = dict((c, i + 1) for i, c in enumerate(words)) 36 num2word = dict((i, c) for i, c in enumerate(words)) 37 word2numF = lambda x: word2num.get(x, 0) 38 return word2numF, num2word, words, files_content
在每行末尾加上 ]
符號是為了標識這首詩已經結束了。我們給模型學習的方法是,給定前六個字,生成第七個字,所以在后面生成訓練數據的時候,會以6的跨度,1的步長截取文字,生成語料。如果出現了 ]
符號,說明 ]
符號之前的語句和之后的語句是兩首詩里面的內容,兩首詩之間是沒有關聯關系的,所以我們后面會舍棄掉包含 ]
符號的訓練數據。
第三,模型參數配置。預先定義模型參數和加載語料以及模型保存名稱,通過類 Config 實現。
1 class Config(object): 2 poetry_file = 'poetry.txt' 3 weight_file = 'poetry_model.h5' 4 # 根據前六個字預測第七個字 5 max_len = 6 6 batch_size = 512 7 learning_rate = 0.001
第四,構建模型,通過 PoetryModel 類實現,類的代碼結構如下:
1 class PoetryModel(object): 2 def __init__(self, config): 3 pass 4 5 def build_model(self): 6 pass 7 8 def sample(self, preds, temperature=1.0): 9 pass 10 11 def generate_sample_result(self, epoch, logs): 12 pass 13 14 def predict(self, text): 15 pass 16 17 def data_generator(self): 18 pass 19 def train(self): 20 pass
類中定義的方法具體實現功能如下:
(1)init 函數定義,通過加載 Config 配置信息,進行語料預處理和模型加載,如果模型文件存在則直接加載模型,否則開始訓練。
1 def __init__(self, config): 2 self.model = None 3 self.do_train = True 4 self.loaded_model = False 5 self.config = config 6 7 # 文件預處理 8 self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config) 9 if os.path.exists(self.config.weight_file): 10 self.model = load_model(self.config.weight_file) 11 self.model.summary() 12 else: 13 self.train() 14 self.do_train = False 15 self.loaded_model = True
(2)build_model
函數主要用 Keras 來構建網絡模型,這里使用 LSTM 的 GRU 來實現,當然直接使用 LSTM 也沒問題。
1 def build_model(self): 2 '''建立模型''' 3 input_tensor = Input(shape=(self.config.max_len,)) 4 embedd = Embedding(len(self.num2word) + 1, 300, input_length=self.config.max_len)(input_tensor) 5 lstm = Bidirectional(GRU(128, return_sequences=True))(embedd) 6 dropout = Dropout(0.6)(lstm) 7 lstm = Bidirectional(GRU(128, return_sequences=True))(embedd) 8 dropout = Dropout(0.6)(lstm) 9 flatten = Flatten()(lstm) 10 dense = Dense(len(self.words), activation='softmax')(flatten) 11 self.model = Model(inputs=input_tensor, outputs=dense) 12 optimizer = Adam(lr=self.config.learning_rate) 13 self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
(3)sample 函數,在訓練過程的每個 epoch 迭代中采樣。
1 def sample(self, preds, temperature=1.0): 2 ''' 3 當temperature=1.0時,模型輸出正常 4 當temperature=0.5時,模型輸出比較open 5 當temperature=1.5時,模型輸出比較保守 6 在訓練的過程中可以看到temperature不同,結果也不同 7 ''' 8 preds = np.asarray(preds).astype('float64') 9 preds = np.log(preds) / temperature 10 exp_preds = np.exp(preds) 11 preds = exp_preds / np.sum(exp_preds) 12 probas = np.random.multinomial(1, preds, 1) 13 return np.argmax(probas)
(4)訓練過程中,每個 epoch 打印出當前的學習情況。
1 def generate_sample_result(self, epoch, logs): 2 print("\n==================Epoch {}=====================".format(epoch)) 3 for diversity in [0.5, 1.0, 1.5]: 4 print("------------Diversity {}--------------".format(diversity)) 5 start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1) 6 generated = '' 7 sentence = self.files_content[start_index: start_index + self.config.max_len] 8 generated += sentence 9 for i in range(20): 10 x_pred = np.zeros((1, self.config.max_len)) 11 for t, char in enumerate(sentence[-6:]): 12 x_pred[0, t] = self.word2numF(char) 13 14 preds = self.model.predict(x_pred, verbose=0)[0] 15 next_index = self.sample(preds, diversity) 16 next_char = self.num2word[next_index] 17 generated += next_char 18 sentence = sentence + next_char 19 print(sentence)
(5)predict 函數,用於根據給定的提示,來進行預測。
根據給出的文字,生成詩句,如果給的 text 不到四個字,則隨機補全。
1 def predict(self, text): 2 if not self.loaded_model: 3 return 4 with open(self.config.poetry_file, 'r', encoding='utf-8') as f: 5 file_list = f.readlines() 6 random_line = random.choice(file_list) 7 # 如果給的text不到四個字,則隨機補全 8 if not text or len(text) != 4: 9 for _ in range(4 - len(text)): 10 random_str_index = random.randrange(0, len(self.words)) 11 text += self.num2word.get(random_str_index) if self.num2word.get(random_str_index) not in [',', '。', 12 ','] else self.num2word.get( 13 random_str_index + 1) 14 seed = random_line[-(self.config.max_len):-1] 15 res = '' 16 seed = 'c' + seed 17 for c in text: 18 seed = seed[1:] + c 19 for j in range(5): 20 x_pred = np.zeros((1, self.config.max_len)) 21 for t, char in enumerate(seed): 22 x_pred[0, t] = self.word2numF(char) 23 preds = self.model.predict(x_pred, verbose=0)[0] 24 next_index = self.sample(preds, 1.0) 25 next_char = self.num2word[next_index] 26 seed = seed[1:] + next_char 27 res += seed 28 return res
(6) data_generator
函數,用於生成數據,提供給模型訓練時使用。
1 def data_generator(self): 2 i = 0 3 while 1: 4 x = self.files_content[i: i + self.config.max_len] 5 y = self.files_content[i + self.config.max_len] 6 puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》', ':'] 7 if len([i for i in puncs if i in x]) != 0: 8 i += 1 9 continue 10 if len([i for i in puncs if i in y]) != 0: 11 i += 1 12 continue 13 y_vec = np.zeros( 14 shape=(1, len(self.words)), 15 dtype=np.bool 16 ) 17 y_vec[0, self.word2numF(y)] = 1.0 18 x_vec = np.zeros( 19 shape=(1, self.config.max_len), 20 dtype=np.int32 21 ) 22 for t, char in enumerate(x): 23 x_vec[0, t] = self.word2numF(char) 24 yield x_vec, y_vec 25 i += 1
(7)train 函數,用來進行模型訓練,其中迭代次數 number_of_epoch
,是根據訓練語料長度除以 batch_size
計算的,如果在調試中,想用更小一點的number_of_epoch
,可以自定義大小,把 train 函數的第一行代碼注釋即可。
1 def train(self): 2 #number_of_epoch = len(self.files_content) // self.config.batch_size 3 number_of_epoch = 10 4 if not self.model: 5 self.build_model() 6 self.model.summary() 7 self.model.fit_generator( 8 generator=self.data_generator(), 9 verbose=True, 10 steps_per_epoch=self.config.batch_size, 11 epochs=number_of_epoch, 12 callbacks=[ 13 keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False), 14 LambdaCallback(on_epoch_end=self.generate_sample_result) 15 ] 16 )
第五,整個模型構建好以后,接下來進行模型訓練。
model = PoetryModel(Config)
訓練過程中的第1-2輪迭代:
訓練過程中的第9-10輪迭代:
雖然訓練過程寫出的詩句不怎么能看得懂,但是可以看到模型從一開始標點符號都不會用 ,到最后寫出了有一點點模樣的詩句,能看到模型變得越來越聰明了。
第六,模型作詩,模型迭代10次之后的測試,首先輸入幾個字,模型根據輸入的提示,做出詩句。
text = input("text:") sentence = model.predict(text) print(sentence)
比如輸入:小雨,模型做出的詩句為:
輸入:text:小雨
結果:小妃侯里守。雨封即客寥。俘剪舟過槽。傲老檳冬絳。
第七,繪制網絡結構圖。
模型結構繪圖,采用 Keras自帶的功能實現:
plot_model(model.model, to_file='model.png')
得到的模型結構圖如下:
本節使用 LSTM 的變形 GRU 訓練出一個能作詩的模型,當然大家可以替換訓練語料為歌詞或者小說,讓機器人自動創作不同風格的歌曲或者小說。
參考文獻以及推薦閱讀: