Tensorflow LSTM實現


Tensorflow[LSTM]

 

0.背景

通過對《tensorflow machine learning cookbook》第9章第3節"implementing_lstm"進行閱讀,發現如下形式可以很方便的進行訓練和預測,通過類進行定義,並利用了tf中的變量重用的能力,使得在訓練階段模型的許多變量,比如權重等,能夠直接用在預測階段。十分方便,不需要自己去做一些權重復制等事情。這里只是簡單記錄下這一小節的源碼中幾個概念性的地方。

# 定義LSTM模型 class LSTM_Model(): def __init__(self, embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size, infer_sample=False): self.embedding_size = embedding_size self.rnn_size = rnn_size #LSTM單元隱層的神經元個數 self.vocab_size = vocab_size self.infer_sample = infer_sample self.learning_rate = learning_rate#學習率 if infer_sample:#如果是inference,則batch size設為1 self.batch_size = 1 self.training_seq_len = 1 else: self.batch_size = batch_size self.training_seq_len = training_seq_len '''建立LSTM單元和初始化state''' self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.rnn_size) self.initial_state = self.lstm_cell.zero_state(self.batch_size, tf.float32) '''進行輸入和輸出的占位''' self.x_data = tf.placeholder(tf.int32, [self.batch_size, self.training_seq_len]) self.y_output = tf.placeholder(tf.int32, [self.batch_size, self.training_seq_len]) with tf.variable_scope('lstm_vars'): # Softmax 部分的權重 W = tf.get_variable('W', [self.rnn_size, self.vocab_size], tf.float32, tf.random_normal_initializer()) b = tf.get_variable('b', [self.vocab_size], tf.float32, tf.constant_initializer(0.0)) # Define Embedding embedding_mat = tf.get_variable('embedding_mat', [self.vocab_size, self.embedding_size], tf.float32, tf.random_normal_initializer()) embedding_output = tf.nn.embedding_lookup(embedding_mat, self.x_data) rnn_inputs = tf.split(axis=1, num_or_size_splits=self.training_seq_len, value=embedding_output) rnn_inputs_trimmed = [tf.squeeze(x, [1]) for x in rnn_inputs] # If we are inferring (generating text), we add a 'loop' function # Define how to get the i+1 th input from the i th output def inferred_loop(prev, count): # Apply hidden layer prev_transformed = tf.matmul(prev, W) + b # Get the index of the output (also don't run the gradient) prev_symbol = tf.stop_gradient(tf.argmax(prev_transformed, 1)) # Get embedded vector output = tf.nn.embedding_lookup(embedding_mat, prev_symbol) return(output) decoder = tf.contrib.legacy_seq2seq.rnn_decoder outputs, last_state = decoder(rnn_inputs_trimmed, self.initial_state, self.lstm_cell, loop_function=inferred_loop if infer_sample else None) # Non inferred outputs output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size]) # Logits and output self.logit_output = tf.matmul(output, W) + b self.model_output = tf.nn.softmax(self.logit_output) loss_fun = tf.contrib.legacy_seq2seq.sequence_loss_by_example loss = loss_fun([self.logit_output],[tf.reshape(self.y_output, [-1])], [tf.ones([self.batch_size * self.training_seq_len])], self.vocab_size) self.cost = tf.reduce_sum(loss) / (self.batch_size * self.training_seq_len) self.final_state = last_state gradients, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tf.trainable_variables()), 4.5) optimizer = tf.train.AdamOptimizer(self.learning_rate) self.train_op = optimizer.apply_gradients(zip(gradients, tf.trainable_variables())) def sample(self, sess, words=ix2vocab, vocab=vocab2ix, num=10, prime_text='thou art'): state = sess.run(self.lstm_cell.zero_state(1, tf.float32)) word_list = prime_text.split() for word in word_list[:-1]: x = np.zeros((1, 1)) x[0, 0] = vocab[word] feed_dict = {self.x_data: x, self.initial_state:state} [state] = sess.run([self.final_state], feed_dict=feed_dict) out_sentence = prime_text word = word_list[-1] for n in range(num): x = np.zeros((1, 1)) x[0, 0] = vocab[word] feed_dict = {self.x_data: x, self.initial_state:state} [model_output, state] = sess.run([self.model_output, self.final_state], feed_dict=feed_dict) sample = np.argmax(model_output[0]) if sample == 0: break word = words[sample] out_sentence = out_sentence + ' ' + word return(out_sentence)

上述代碼就建立好了lstm的網絡結構,其中想要說明的重點就是,如往常一樣構建lstm結構,其中BasicLSTMCell中的權重和上述的lstm_vars一樣是有variable_scope的

# 定義訓練階段的lstm lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size) # 定義測試階段的lstm with tf.variable_scope(tf.get_variable_scope(), reuse=True): test_lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size, infer_sample=True)

上述代碼通過先建立一個訓練的lstm結構,然后采用全局變量重用的形式,使得inference的lstm中的變量都方便的使用train階段的變量。
下面是訓練和inference的代碼

# Train model train_loss = [] iteration_count = 1 for epoch in range(epochs): # Shuffle word indices random.shuffle(batches) # Create targets from shuffled batches targets = [np.roll(x, -1, axis=1) for x in batches] # Run a through one epoch print('Starting Epoch #{} of {}.'.format(epoch+1, epochs)) # Reset initial LSTM state every epoch state = sess.run(lstm_model.initial_state) for ix, batch in enumerate(batches): training_dict = {lstm_model.x_data: batch, lstm_model.y_output: targets[ix]} '''每個batch的LSTM中初始化狀態c和h,其狀態被賦值為上一個batch的LSTM的最終狀態的c和h ''' '''也就是前后相接 ''' c, h = lstm_model.initial_state training_dict[c] = state.c training_dict[h] = state.h temp_loss, state, _ = sess.run([lstm_model.cost, lstm_model.final_state, lstm_model.train_op], feed_dict=training_dict) train_loss.append(temp_loss) # Print status every 10 gens if iteration_count % 10 == 0: summary_nums = (iteration_count, epoch+1, ix+1, num_batches+1, temp_loss) print('Iteration: {}, Epoch: {}, Batch: {} out of {}, Loss: {:.2f}'.format(*summary_nums)) if iteration_count % eval_every == 0: for sample in prime_texts: print(test_lstm_model.sample(sess, ix2vocab, vocab2ix, num=10, prime_text=sample)) iteration_count += 1

在后續的訓練中只要正常訓練和測試即可,其中inference階段時候lstm中的權重,全都會自動的從訓練階段直接拿來用,在"site-packages/tensorflow/python/ops/rnn_cell_impl.py"的1240行

  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope: weights = vs.get_variable( _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], dtype=dtype, initializer=kernel_initializer)

如上述代碼中所示,當采用了全局變量重用功能之后,就無需手動去復制train好的權重到inference階段了。


圖0.1 graph圖,左邊紅框是train的結構;右邊紅框是inference的結構


圖0.2 基於圖0.1的局部放大


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM