Estimating the FLOPs and Parameter Count of a TensorFlow Model


2018-08-28

This post shows how to count the floating-point operations (FLOPs) and the number of parameters of a TensorFlow model.
stats_graph.py

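import tensorflow as tf  # TensorFlow 1.x API (available as tf.compat.v1 in TensorFlow 2.x)


def stats_graph(graph):
    # Sum the FLOPs of every op in the graph that has a registered FLOP statistic.
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    # Count the elements of all trainable variables.
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))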

Initializing variables from a Gaussian distribution costs some FLOPs

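$$C_{[25,9]} = A_{[25,16]}\,B_{[16,9]}$$

Each of the $25 \times 9$ output elements needs 16 multiplications and 15 additions. The TensorFlow profiler instead counts each multiply-add as 2 FLOPs:

$$\mathrm{FLOPs} = (16 + 15) \times (25 \times 9) = 6975$$
$$\mathrm{FLOPs}_{\text{TF style}} = (16 + 16) \times (25 \times 9) = 7200$$
$$\text{total\_parameters} = 25 \times 16 + 16 \times 9 = 544$$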
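The arithmetic above can be checked with a few lines of plain Python. This sketch implements the two counting conventions for a product of an (m x n) matrix and an (n x p) matrix. The helper names are hypothetical and are not part of TensorFlow.

```python
def matmul_flops_exact(m, n, p):
    """Multiplications plus additions for C[m,p] = A[m,n] @ B[n,p]:
    each output element needs n multiplications and n - 1 additions."""
    return (n + (n - 1)) * (m * p)


def matmul_flops_tf(m, n, p):
    """TensorFlow-profiler style: every multiply-add counts as 2 FLOPs."""
    return 2 * n * (m * p)


def param_count(*shapes):
    """Total number of elements across the given 2-D variable shapes."""
    total = 0
    for rows, cols in shapes:
        total += rows * cols
    return total


print(matmul_flops_exact(25, 16, 9))    # 6975
print(matmul_flops_tf(25, 16, 9))       # 7200
print(param_count((25, 16), (16, 9)))   # 544
```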

with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')
    
    stats_graph(graph)

Output:
FLOPs: 8288; Trainable params: 544
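Where do the extra 8288 - 7200 = 1088 FLOPs come from? Here is one plausible accounting. It assumes that tf.random_normal_initializer draws a standard-normal sample, multiplies it by stddev, and then adds mean. It also assumes the profiler charges 1 FLOP per element for each of those two element-wise ops. The result is consistent with the 7200 FLOPs reported when constant initializers, which do no arithmetic, are used instead.

```python
MATMUL_FLOPS = 2 * 16 * (25 * 9)   # TF-style count for A[25,16] @ B[16,9]
num_elements = 25 * 16 + 16 * 9    # elements initialized in A and B

# Assumption: one Mul (scale by stddev) and one Add (shift by mean) per element.
init_flops = 2 * num_elements

print(init_flops)                  # 1088
print(MATMUL_FLOPS + init_flops)   # 8288
```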

Initializing variables with constant initializers costs no FLOPs

with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')
    
    stats_graph(graph)

Output:
FLOPs: 7200; Trainable params: 544

Frozen graph

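We usually do not care about the FLOPs spent on initialization, because they are incurred only once, before training starts. What matters is the FLOPs of the model once it is deployed in production. Freezing the graph, which converts the variables into constants, gives exactly that: the FLOPs consumed during inference, with the initialization FLOPs excluded.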
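Why does freezing drop the initializer cost? In TensorFlow 1.x, convert_variables_to_constants first extracts only the subgraph needed to compute the requested output nodes. The initializer ops feed each variable's Assign op rather than the output, so they fall outside that subgraph. The toy model below mimics this in plain Python; the op names and the dictionary layout are hypothetical and are not TensorFlow's real data structures. Before freezing, the profiler sums every op in the graph. After freezing, only the ops reachable backwards from output remain.

```python
# Toy graph: op name -> (FLOPs the profiler would charge, input op names).
# Op names are hypothetical; FLOP values follow the 25x16 @ 16x9 example.
TOY_GRAPH = {
    "A": (0, []),                             # variable holding A[25, 16]
    "A/randn": (0, []),                       # standard-normal sample, no FLOPs
    "A/mul": (25 * 16, ["A/randn"]),          # scale by stddev
    "A/add": (25 * 16, ["A/mul"]),            # shift by mean
    "A/Assign": (0, ["A", "A/add"]),          # writes the initial value into A
    "B": (0, []),                             # variable holding B[16, 9]
    "B/randn": (0, []),
    "B/mul": (16 * 9, ["B/randn"]),
    "B/add": (16 * 9, ["B/mul"]),
    "B/Assign": (0, ["B", "B/add"]),
    "output": (2 * 16 * 25 * 9, ["A", "B"]),  # MatMul, TF-style count
}


def reachable_from(graph, target):
    """Collect every op the target depends on by walking inputs backwards."""
    seen, stack = set(), [target]
    while stack:
        name = stack.pop()
        if name not in seen:
            seen.add(name)
            stack.extend(graph[name][1])
    return seen


def total_flops(graph, names):
    """Sum the FLOPs of the named ops."""
    return sum(graph[name][0] for name in names)


print(total_flops(TOY_GRAPH, TOY_GRAPH))                            # 8288
print(total_flops(TOY_GRAPH, reachable_from(TOY_GRAPH, "output")))  # 7200
```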

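import tensorflow as tf
from tensorflow.python.framework import graph_util


def load_pb(pb):
    # Read a serialized GraphDef (the frozen graph) from disk.
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import it into a fresh graph; name='' keeps the original op names.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

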
with tf.Graph().as_default() as graph:
    # ***** (1) Create Graph *****
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')
    
    print('stats before freezing')
    stats_graph(graph)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())
# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)

Output:

stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0

Using it with Keras

from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential
from keras.initializers import Constant
model = Sequential()
model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))
sess = K.get_session()
graph = sess.graph
stats_graph(graph)

Output:
FLOPs: 0; Trainable params: 160
Using TensorFlow backend.
2 ops no flops stats due to incomplete shapes.
2 ops no flops stats due to incomplete shapes.
Calling model.summary() reports the same parameter count:

Layer (type)          Output Shape    Param #
dense_1 (Dense)       (None, 32)      160

Total params: 160
Trainable params: 160
Non-trainable params: 0
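The profiler reports 0 FLOPs here because the Dense layer's input has shape (None, 4). With an unknown batch dimension, the layer's ops have incomplete shapes and the profiler skips them; the two ops in the warning are presumably the MatMul and the BiasAdd. A per-batch estimate can still be worked out by hand. This sketch uses the same TF-style conventions as above: 2 FLOPs per multiply-add in the MatMul and 1 FLOP per element in the bias add. The helper names are hypothetical.

```python
def dense_flops(batch_size, input_dim, units, use_bias=True):
    """TF-style FLOP estimate for a Dense layer applied to a batch of inputs."""
    matmul = 2 * input_dim * units * batch_size   # multiply-adds in x @ W
    bias = units * batch_size if use_bias else 0  # one add per output element
    return matmul + bias


def dense_params(input_dim, units, use_bias=True):
    """Kernel (input_dim x units) plus an optional bias vector (units)."""
    return input_dim * units + (units if use_bias else 0)


print(dense_params(4, 32))     # 160, matching model.summary()
print(dense_flops(1, 4, 32))   # 288 FLOPs per sample
print(dense_flops(64, 4, 32))  # 18432 for a batch of 64
```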



About

This is Robert Lexis (FengCun Li). To see the world, things dangerous to come to, to see behind walls, to draw closer, to find each other and to feel. That is the purpose of LIFE.