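Comparing TensorFlow and PyTorch: differences in static and dynamic graph construction
In a static-graph framework, the graph cannot be modified once it has been defined, and defining it requires a special syntax of its own. Data-dependent control flow therefore cannot be written with Python's if, while, or for-loop; it has to go through constructs that the framework provides specifically for this purpose. When building the graph we must also account for every possible case: the subgraph of every if branch has to be present in the graph, even though a given run may not use it. This can make a static graph very large and cause it to take up a lot of GPU memory.
Dynamic graphs do not have this concern. They are compatible with Python's ordinary control-flow syntax, and the graph that finally gets created depends on which branch is taken on each run. Below we compare how TensorFlow and PyTorch build a graph containing an if branch: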
# Author : Hellcat
# Time : 2018/2/9


def tf_graph_if():
    # Uses the TensorFlow 1.x API (tf.placeholder, tf.Session).
    import numpy as np
    import tensorflow as tf

    # Placeholders are filled in through feed_dict at run time.
    x = tf.placeholder(tf.float32, shape=(3, 4))
    z = tf.placeholder(tf.float32, shape=None)
    w1 = tf.placeholder(tf.float32, shape=(4, 5))
    w2 = tf.placeholder(tf.float32, shape=(4, 5))

    # Both branches are defined as part of the graph.
    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    # The branch is chosen inside the graph: f1 if z < 0, otherwise f2.
    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        y_out = sess.run(y, feed_dict={
            x: np.random.randn(3, 4),
            z: 10,
            w1: np.random.randn(4, 5),
            w2: np.random.randn(4, 5)})
    return y_out


def t_graph_if():
    import torch as t
    from torch.autograd import Variable

    x = Variable(t.randn(3, 4))
    w1 = Variable(t.randn(4, 5))
    w2 = Variable(t.randn(4, 5))
    z = 10

    # Plain Python if; the test mirrors tf.less(z, 0) in tf_graph_if.
    if z < 0:
        y = x.mm(w1)
    else:
        y = x.mm(w2)
    return y


if __name__ == "__main__":
    print(tf_graph_if())
    print(t_graph_if())
The output is as follows:
[[ 4.0871315 0.90317607 -4.65211582 0.71610922 -2.70281982]
[ 3.67874336 -0.58160967 -3.43737102 1.9781189 -2.18779659]
[ 2.6638422 -0.81783319 -0.30386463 -0.61386991 -3.80232286]]
Variable containing:
-0.2474 0.1269 0.0830 3.4642 0.2255
0.7555 -0.8057 -2.8159 3.7416 0.6230
0.9010 -0.9469 -2.5086 -0.8848 0.2499
[torch.FloatTensor of size 3x5]
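Personally, I feel the comparison above is not quite ideal. To make the comparison with TensorFlow variables instead, the tf_graph_if function should be rewritten as follows: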
# Author : Hellcat
# Time : 2018/2/9


def tf_graph_if():
    import tensorflow as tf

    x = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[3, 4]))
    z = tf.constant(dtype=tf.float32, value=10)
    w1 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
    w2 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_out = sess.run(y)
    return y_out
The output looks much the same as before:
[[ 1.89582038 1.12734962 0.59730953 0.99833554 0.86517167]
[ 1.2659111 0.77320379 0.63649696 0.5804953 0.82271856]
[ 1.92151642 1.64715886 1.19869363 1.31581473 1.5636673 ]]
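Neither listing shows the graph itself, so here is a framework-free toy that illustrates the earlier point about graph size. It is only a sketch: the StaticCond class and the dynamic_trace function are hypothetical names made up for this illustration and are not part of TensorFlow or PyTorch. The static conditional stores both branches at definition time, while the dynamic version records only the operation that actually ran.

```python
# Toy model only: StaticCond and dynamic_trace are hypothetical,
# not TensorFlow or PyTorch APIs.


class StaticCond:
    """Static-graph conditional: both branches are stored at definition time."""

    def __init__(self, predicate, true_ops, false_ops):
        self.predicate = predicate
        self.true_ops = list(true_ops)
        self.false_ops = list(false_ops)

    def node_count(self):
        # One node for the conditional itself plus every op of every branch,
        # regardless of which branch a given run will take.
        return 1 + len(self.true_ops) + len(self.false_ops)

    def run(self, z):
        # Only at run time is one of the stored branches selected.
        ops = self.true_ops if self.predicate(z) else self.false_ops
        return list(ops)


def dynamic_trace(z):
    """Dynamic graph: ops are recorded while ordinary Python control flow runs."""
    tape = []
    if z < 0:
        tape.append("matmul(x, w1)")
    else:
        tape.append("matmul(x, w2)")
    return tape


static_graph = StaticCond(lambda z: z < 0, ["matmul(x, w1)"], ["matmul(x, w2)"])
print(static_graph.node_count())  # conditional node plus both branches
print(static_graph.run(10))       # ops executed for z = 10
print(dynamic_trace(10))          # ops recorded for z = 10
```

With z = 10 both versions execute the same single matmul, but the static structure still carries the unused branch around.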
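As we can see, TensorFlow implements the if branch through the function tf.cond(tf.less(z, 0), f1, f2), which is quite different from PyTorch's direct use of if. In addition, the dynamic graph needs no feed and can simply be run directly. Even from this simple comparison, PyTorch's logic is clearly more concise, which makes it very appealing.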
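One way to check that the dynamic graph really contains only the branch that was taken is to look at gradients. The sketch below assumes PyTorch 0.4 or later, where Variable has been merged into Tensor and requires_grad=True is set on the tensor directly. The function name t_graph_if_grad is my own, and the condition is written as z < 0 to mirror tf.less(z, 0).

```python
import torch


def t_graph_if_grad(z=10):
    # Hypothetical helper for illustration; plain tensors replace Variable.
    x = torch.randn(3, 4)
    w1 = torch.randn(4, 5, requires_grad=True)
    w2 = torch.randn(4, 5, requires_grad=True)

    # Ordinary Python control flow: only the branch that executes
    # becomes part of the autograd graph.
    if z < 0:
        y = x.mm(w1)
    else:
        y = x.mm(w2)

    # Backpropagate through whatever graph was built on this run.
    y.sum().backward()
    return y, w1.grad, w2.grad


y, g1, g2 = t_graph_if_grad(10)
print(y.shape)
print(g1 is None, g2 is None)  # True False: w1 never entered the graph
```

For z = 10 the else branch runs, so w2 receives a gradient while w1.grad stays None; calling the function with a negative z reverses this.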
