首先我們分析一下下面的代碼:
import tensorflow as tf import numpy as np a=tf.constant([[1., 2., 3.],[4., 5., 6.]]) b=np.float32(np.random.randn(3,2)) #c=tf.matmul(a,b) c=tf.multiply(a,b) init=tf.global_variables_initializer() with tf.Session() as sess: print(c.eval())
問題是上面的代碼編譯正確嗎?編譯一下就知道,錯誤信息如下:
ValueError: Dimensions must be equal, but are 2 and 3 for 'Mul' (op: 'Mul') with input shapes: [2,3], [3,2].
顯然,tf.multiply()表示點積,因此維度要一樣。而tf.matmul()表示普通的矩陣乘法。
而且tf.multiply(a,b)和tf.matmul(a,b)都要求a和b的類型必須一致。但是之間存在着細微的區別。
在tf中所有返回的tensor,不管傳進去是什么類型,傳出來的都是numpy ndarray對象。
看看官網API介紹:
tf.matmul( a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None ) tf.multiply( x, y, name=None )
但是tf.matmul(a,b)函數不僅要求a和b的類型必須完全一致,同時返回的tensor類型同a和b一致;而tf.multiply(a,b)函數僅要求a和b的類型顯式一致,同時返回的tensor類型與a一致,即在不聲明類型的情況下,編譯不報錯。
例如:
#類型一致,可以運行 import tensorflow as tf import numpy as np a=tf.constant([[1, 2, 3],[4, 5, 6]],dtype=np.float32) b=np.float32(np.random.randn(3,2)) c=tf.matmul(a,b) #c=tf.multiply(a,b) init=tf.global_variables_initializer() with tf.Session() as sess: print (type(c.eval()),type(a.eval()),type(b))
#類型不一致,不可以運行 import tensorflow as tf import numpy as np a=tf.constant([[1, 2, 3],[4, 5, 6]]) b=np.float32(np.random.randn(3,2)) c=tf.matmul(a,b) #c=tf.multiply(a,b) init=tf.global_variables_initializer() with tf.Session() as sess: print (type(c.eval()),type(a.eval()),type(b))
#類型不一致,可以運行,結果的類型和a一致 import tensorflow as tf import numpy as np a=tf.constant([[1, 2, 3],[4, 5, 6]]) b=np.float32(np.random.randn(2,3)) #c=tf.matmul(a,b) c=tf.multiply(a,b) init=tf.global_variables_initializer() with tf.Session() as sess: print (c.eval()) print (type(c.eval()),type(a.eval()),type(b))
#類型不一致,不可以運行 import tensorflow as tf import numpy as np a=tf.constant([[1, 2, 3],[4, 5, 6]], dtype=np.float32) b=tf.constant([[1, 2, 3],[4, 5, 6]], dtype=np.int32) #c=tf.matmul(a,b) c=tf.multiply(a,b) init=tf.global_variables_initializer() with tf.Session() as sess: print (c.eval()) print (type(c.eval()),type(a.eval()),type(b))