tensorflow BasicRNNCell調試


運行以下代碼,進入~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py和~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py單步追蹤調試

調試中import tensorflow as tf,利用tf.Session().run(variable)打印變量

查看BasicRNNCell和dynamic_rnn的實現方式

 

 1 #-*-coding:utf8-*-
 2 
 3 __author = "buyizhiyou"
 4 __date = "2017-11-20"
 5 
 6 '''
 7 單步調試,學習rnn的tf實現
 8 '''
 9 import tensorflow as tf 
10 import numpy as np
11 import pdb  
12   
13 X = tf.random_normal(shape=[2,3,4], dtype=tf.float32)#(2,3,4)==>(Batch_size,Time_steps(序列長度),Data_Vector)
14 pdb.set_trace()  
15 cell = tf.nn.rnn_cell.BasicRNNCell(10)#output_size:10,也可以換成GRUCell,LSTMAACell,BasicRNNCell  
16 state = cell.zero_state(2, tf.float32)#batch_size:2  
17 output, state = tf.nn.dynamic_rnn(cell, X, initial_state=state, time_major=False)  
18 with tf.Session() as sess:  
19     sess.run(tf.global_variables_initializer())  
20     print (output.get_shape())
21     print (sess.run(state))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM