循環神經網絡(RNN)的代碼實現


代碼部分

import tensorflow as tf
import tensorflow.contrib as rnn #引入RNN
form tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("data/", one_hot=True)
batch_size = 128 #定義參數
#定義訓練數據
x = tf.placeholder("float", [None, 28, 28])
y = tf.placeholder("float", [None, 10])
#定義w和b
weights = {
    'out': tf.Variable(tf.random_normal([128, 10]))}
biases = {
    'out': tf.Variable(tf.random_normal([10]))
}
def RNN(x, weights, biases):
    #按照RNN的方式處理輸入層
    x = tf.unstack(x, 28, 1)
    #lstm層
    #forget_bias (默認為1)到遺忘門的偏置,為了減少在開始訓練時遺忘的規模
    lstm_cell = rnn.BasicLSTMCell(128, forget_bias=1.0)
    #獲得lstm層的輸出
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    #得到最后一層的輸出
    return rf.matmul(outputs[-1], weights['out'])+biases['out']
    
#預測值
pred = RNN(x, weights,biases)
#定義代價函數和最優算法
#尋找全局最優點的優化算法,引入了二次方梯度矯正
#AdamOptimizer 不容易陷於局部優點,速度更快
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pre, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).mnimizer(cost)
#結果對比
correct_pred = tf.wqual(tf.argmax(pred, 1),tf.argmax(y, 1))
#求正確率
accuracy = tf.reduce_mean(tf.case(corrext_pred, tf.float32))
#初始化所有參數
init = tf.initializer_all_variables()
with tf.Session() as sess:
    sess.run(init)
    step = 1
    
    while step * batch_size < 100000:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape((batch_size,28,28))
        sess.run(optimizer, feed_dict={x: batch_x,y:batch_y})
        if step % 10 == 0:
            acc = sess.run(accuracy, feed_sict={x: batch_x,y:batch_y})
            loss = sess.run(cost, feed_dict={x: batch_x, y:batch_y})
            print "iter" + str(step * batch_size) + ",minibatch loss ="+ loss + acc
        step += 1
    print "optimization finished"
    #數據驗證
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1,28,28))
    test_label = mnist.test.labels[:test_len]
    print "testing accuracy"+sess.run(accuracy, feed_dict={x: test_data,y: test_label})


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM