基於 LSTM 輕松生成各種古詩


整個過程分為以下步驟完成:

  1. 語料准備
  2. 語料預處理
  3. 模型參數配置
  4. 構建模型
  5. 訓練模型
  6. 模型作詩
  7. 繪制模型網絡結構圖

下面一步步來構建和訓練一個會寫詩的模型。

第一,語料准備。一共四萬多首古詩,每行一首詩,標題在預處理的時候已經去掉了。

第二,文件預處理。首先,機器並不懂每個中文漢字代表的是什么,所以要將文字轉換為機器能理解的形式,這里我們采用 One-Hot 的形式,這樣詩句中的每個字都能用向量來表示,下面定義函數 preprocess_file() 來處理。

 1 puncs = [']', '[', '', '', '{', '}', '', '', '']
 2 
 3 
 4 def preprocess_file(Config):
 5     # 語料文本內容
 6     files_content = ''
 7     with open(Config.poetry_file, 'r', encoding='utf-8') as f:
 8         for line in f:
 9             # 每行的末尾加上"]"符號代表一首詩結束
10             for char in puncs:
11                 line = line.replace(char, "")
12             files_content += line.strip() + "]"
13 
14     words = sorted(list(files_content))
15     words.remove(']')
16     counted_words = {}
17     for word in words:
18         if word in counted_words:
19             counted_words[word] += 1
20         else:
21             counted_words[word] = 1
22 
23     # 去掉低頻的字
24     erase = []
25     for key in counted_words:
26         if counted_words[key] <= 2:
27             erase.append(key)
28     for key in erase:
29         del counted_words[key]
30     del counted_words[']']
31     wordPairs = sorted(counted_words.items(), key=lambda x: -x[1])
32 
33     words, _ = zip(*wordPairs)
34     # word到id的映射
35     word2num = dict((c, i + 1) for i, c in enumerate(words))
36     num2word = dict((i, c) for i, c in enumerate(words))
37     word2numF = lambda x: word2num.get(x, 0)
38     return word2numF, num2word, words, files_content

在每行末尾加上 ] 符號是為了標識這首詩已經結束了。我們給模型學習的方法是,給定前六個字,生成第七個字,所以在后面生成訓練數據的時候,會以6的跨度,1的步長截取文字,生成語料。如果出現了 ] 符號,說明 ] 符號之前的語句和之后的語句是兩首詩里面的內容,兩首詩之間是沒有關聯關系的,所以我們后面會舍棄掉包含 ] 符號的訓練數據。

第三,模型參數配置。預先定義模型參數和加載語料以及模型保存名稱,通過類 Config 實現。

1 class Config(object):
2     poetry_file = 'poetry.txt'
3     weight_file = 'poetry_model.h5'
4     # 根據前六個字預測第七個字
5     max_len = 6
6     batch_size = 512
7     learning_rate = 0.001

第四,構建模型,通過 PoetryModel 類實現,類的代碼結構如下:

 1  class PoetryModel(object):
 2         def __init__(self, config):
 3             pass
 4 
 5         def build_model(self):
 6             pass
 7 
 8         def sample(self, preds, temperature=1.0):
 9             pass
10 
11         def generate_sample_result(self, epoch, logs):
12             pass
13 
14         def predict(self, text):
15             pass
16 
17         def data_generator(self):
18             pass
19         def train(self):
20             pass

類中定義的方法具體實現功能如下:

(1)init 函數定義,通過加載 Config 配置信息,進行語料預處理和模型加載,如果模型文件存在則直接加載模型,否則開始訓練。

 1  def __init__(self, config):
 2             self.model = None
 3             self.do_train = True
 4             self.loaded_model = False
 5             self.config = config
 6 
 7             # 文件預處理
 8             self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)
 9             if os.path.exists(self.config.weight_file):
10                 self.model = load_model(self.config.weight_file)
11                 self.model.summary()
12             else:
13                 self.train()
14             self.do_train = False
15             self.loaded_model = True

(2)build_model 函數主要用 Keras 來構建網絡模型,這里使用 LSTM 的 GRU 來實現,當然直接使用 LSTM 也沒問題。

 1 def build_model(self):
 2     '''建立模型'''
 3     input_tensor = Input(shape=(self.config.max_len,))
 4     embedd = Embedding(len(self.num2word) + 1, 300, input_length=self.config.max_len)(input_tensor)
 5     lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)
 6     dropout = Dropout(0.6)(lstm)
 7     lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)
 8     dropout = Dropout(0.6)(lstm)
 9     flatten = Flatten()(lstm)
10     dense = Dense(len(self.words), activation='softmax')(flatten)
11     self.model = Model(inputs=input_tensor, outputs=dense)
12     optimizer = Adam(lr=self.config.learning_rate)
13     self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

 

(3)sample 函數,在訓練過程的每個 epoch 迭代中采樣。

 1    def sample(self, preds, temperature=1.0):
 2             '''
 3             當temperature=1.0時,模型輸出正常
 4             當temperature=0.5時,模型輸出比較open
 5             當temperature=1.5時,模型輸出比較保守
 6             在訓練的過程中可以看到temperature不同,結果也不同
 7             '''
 8             preds = np.asarray(preds).astype('float64')
 9             preds = np.log(preds) / temperature
10             exp_preds = np.exp(preds)
11             preds = exp_preds / np.sum(exp_preds)
12             probas = np.random.multinomial(1, preds, 1)
13             return np.argmax(probas)

(4)訓練過程中,每個 epoch 打印出當前的學習情況。

 1 def generate_sample_result(self, epoch, logs):
 2     print("\n==================Epoch {}=====================".format(epoch))
 3     for diversity in [0.5, 1.0, 1.5]:
 4         print("------------Diversity {}--------------".format(diversity))
 5         start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1)
 6         generated = ''
 7         sentence = self.files_content[start_index: start_index + self.config.max_len]
 8         generated += sentence
 9         for i in range(20):
10             x_pred = np.zeros((1, self.config.max_len))
11             for t, char in enumerate(sentence[-6:]):
12                 x_pred[0, t] = self.word2numF(char)
13 
14             preds = self.model.predict(x_pred, verbose=0)[0]
15             next_index = self.sample(preds, diversity)
16             next_char = self.num2word[next_index]
17             generated += next_char
18             sentence = sentence + next_char
19         print(sentence)

(5)predict 函數,用於根據給定的提示,來進行預測。

根據給出的文字,生成詩句,如果給的 text 不到四個字,則隨機補全。

 1 def predict(self, text):
 2         if not self.loaded_model:
 3             return
 4         with open(self.config.poetry_file, 'r', encoding='utf-8') as f:
 5             file_list = f.readlines()
 6         random_line = random.choice(file_list)
 7         # 如果給的text不到四個字,則隨機補全
 8         if not text or len(text) != 4:
 9             for _ in range(4 - len(text)):
10                 random_str_index = random.randrange(0, len(self.words))
11                 text += self.num2word.get(random_str_index) if self.num2word.get(random_str_index) not in [',', '',
12                                                                                                            ''] else self.num2word.get(
13                     random_str_index + 1)
14         seed = random_line[-(self.config.max_len):-1]
15         res = ''
16         seed = 'c' + seed
17         for c in text:
18             seed = seed[1:] + c
19             for j in range(5):
20                 x_pred = np.zeros((1, self.config.max_len))
21                 for t, char in enumerate(seed):
22                     x_pred[0, t] = self.word2numF(char)
23                 preds = self.model.predict(x_pred, verbose=0)[0]
24                 next_index = self.sample(preds, 1.0)
25                 next_char = self.num2word[next_index]
26                 seed = seed[1:] + next_char
27             res += seed
28         return res

(6) data_generator 函數,用於生成數據,提供給模型訓練時使用。

 1 def data_generator(self):
 2     i = 0
 3     while 1:
 4         x = self.files_content[i: i + self.config.max_len]
 5         y = self.files_content[i + self.config.max_len]
 6         puncs = [']', '[', '', '', '{', '}', '', '', '', ':']
 7         if len([i for i in puncs if i in x]) != 0:
 8             i += 1
 9             continue
10         if len([i for i in puncs if i in y]) != 0:
11             i += 1
12             continue
13         y_vec = np.zeros(
14             shape=(1, len(self.words)),
15             dtype=np.bool
16         )
17         y_vec[0, self.word2numF(y)] = 1.0
18         x_vec = np.zeros(
19             shape=(1, self.config.max_len),
20             dtype=np.int32
21         )
22         for t, char in enumerate(x):
23             x_vec[0, t] = self.word2numF(char)
24         yield x_vec, y_vec
25         i += 1

(7)train 函數,用來進行模型訓練,其中迭代次數 number_of_epoch ,是根據訓練語料長度除以 batch_size 計算的,如果在調試中,想用更小一點的number_of_epoch ,可以自定義大小,把 train 函數的第一行代碼注釋即可。

 1 def train(self):
 2         #number_of_epoch = len(self.files_content) // self.config.batch_size
 3         number_of_epoch = 10
 4         if not self.model:
 5             self.build_model()
 6         self.model.summary()
 7         self.model.fit_generator(
 8             generator=self.data_generator(),
 9             verbose=True,
10             steps_per_epoch=self.config.batch_size,
11             epochs=number_of_epoch,
12             callbacks=[
13                 keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),
14                 LambdaCallback(on_epoch_end=self.generate_sample_result)
15             ]
16         )

第五,整個模型構建好以后,接下來進行模型訓練。

 model = PoetryModel(Config)

訓練過程中的第1-2輪迭代:

enter image description here

訓練過程中的第9-10輪迭代:

enter image description here

雖然訓練過程寫出的詩句不怎么能看得懂,但是可以看到模型從一開始標點符號都不會用 ,到最后寫出了有一點點模樣的詩句,能看到模型變得越來越聰明了。

第六,模型作詩,模型迭代10次之后的測試,首先輸入幾個字,模型根據輸入的提示,做出詩句。

    text = input("text:")
    sentence = model.predict(text)
    print(sentence)

比如輸入:小雨,模型做出的詩句為:

輸入:text:小雨

結果:小妃侯里守。雨封即客寥。俘剪舟過槽。傲老檳冬絳。

第七,繪制網絡結構圖。

模型結構繪圖,采用 Keras自帶的功能實現:

    plot_model(model.model, to_file='model.png')

得到的模型結構圖如下:

enter image description here

本節使用 LSTM 的變形 GRU 訓練出一個能作詩的模型,當然大家可以替換訓練語料為歌詞或者小說,讓機器人自動創作不同風格的歌曲或者小說。

參考文獻以及推薦閱讀:

  1. 基於 Keras 和 LSTM 的文本生成


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM