tensorflow中出現{TypeError}unhashable type: 'numpy.ndarray'


本人實驗中使用feed的方式填充數據,sess處的代碼如下:

1 with tf.Session() as sess:
2     init = tf.global_variables_initializer()
3     sess.run(init)
4     for epoch in range(a.epochs):
5         input, target = load_batch_data(batch_size=16, a=a)
6         batch_input = input.astype(np.float32)
7         batch_target = target.astype(np.float32)
8         sess.run(predict_real, feed_dict={input: batch_input, target: batch_target})

運行的時候出現:{TypeError}unhashable type: 'numpy.ndarray'

后  來  發  現:

在session外邊定義input和target的時候是這么寫的:

1 input = tf.placeholder(dtype=tf.float32, shape=[None, image_size, image_size, num_channels])
2 target = tf.placeholder(dtype=tf.float32, shape=[None, image_size, image_size, num_channels])

然而,我在開啟session后又定義了input,target。這導致我在運行下面這行代碼的時候,

1 sess.run(predict_real, feed_dict={input: batch_input, target: batch_target})

出現了{TypeError}unhashable type: 'numpy.ndarray'這樣的錯誤。然而此input和target非session外面的input和target。知道是這個原因后,改正的話就很簡單了,修改session內input和target的名稱即可,如下:

 1     with tf.Session() as sess:
 2         init = tf.global_variables_initializer()
 3         sess.run(init)
 4         if a.mode == 'train':
 5             for epoch in range(a.epochs):
 6                 batch_input, batch_target = load_batch_data(a=a)
 7                 batch_input = batch_input.astype(np.float32)
 8                 batch_target = batch_target.astype(np.float32)
 9                 sess.run(model, feed_dict={input: batch_input, target: batch_target})
10                 print('epoch' + str(epoch) + ':')
11             saver.save(sess, 'model_parameter/train.ckpt')
12             print('training finished!!!')
13         elif a.mode == 'test':
14             #ceshi
15             ckpt = tf.train.latest_checkpoint(a.checkpoint)
16             saver.restore(sess, ckpt)
17             # 獲取測試時候的圖像,然后添加標簽
18             batch_input, _ = load_batch_data(a=a)
19             # batch_input = batch_input / 255.
20             batch_input = batch_input.astype(np.float32)
21             generator_output = sess.run(test_output, feed_dict={input: batch_input})
22             # 對結果進行處理,圖像通道上減去3,得到rgb圖像
23             result = process_generator_output(generator_output)
24             if result:
25                 print('測試完成!')
26         else:
27             print('the MODE is not avaliable...')

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM