paddle08-高層API(paddle.Model/summary/flops)


🌟 高層API 相關 3個

  • 1.class paddle.Model(net, inputs, labels) 一個具備訓練、測試、推理的神經網絡

    該對象同時支持靜態圖和動態圖模式,飛槳框架默認為動態圖模式,  
    通過 paddle.enable_static() 來切換到靜態圖模式。需要注意的是,需要在實例化 Model 對象之前完成切換。
    
    * network (paddle.nn.Layer): 是 paddle.nn.Layer 的一個實例
    * inputs: 
    * labels: 
    
    
    • 方法:train_batch(inputs, labels= None)
    • 方法: eval_batch(inputs, labels= None)
    • 方法: predict_batch(inputs)
    • 方法: save(path, training= True)
    • 方法: load(path, skip_mismatch=False, reset_optimizer= False)
    • 方法: parameters(*args, **kwargs)
    • 方法: prepare(optimizer=None, loss=None, metrics= None)
    • 方法: fit(train_data= None, eval_data= None, batch_size=1, epochs= 1, eval_freq= 1, log_freq=10, save_dir= None
      save_freq= 1, verbose= 2, drop_last= False, shuffle= True, num_workers= 0, callbacks= None)
    • 方法: evaluate(eval_data, batch_size=1, log_freq= 10, verbose= 2, num_workers=0, callbacks= None)
    • 方法: predict(test_data, batch_size=1, num_workers=0, stack_outputs= False, callbacks= None)
    • 方法: summary(input_size= None, batch_size= None, dtype= None)
  • 2.paddle.summary(net, input_size, dtypes=None) 打印網絡的基礎結構和參數信息

    返回:字典,包含了總的參數量和總的可訓練的參數量。
    
    * net: Layer, 網絡實例,必須是 Layer 的子類 
    * input_size: tuple/list/InputSpec   輸入張量的大小. 
    如果網絡只有一個輸入,那么該值需要設定為tuple或InputSpec。  
    如果模型有多個輸入。那么該值需要設定為list[tuple|InputSpec],包含每個輸入的shape。  
    
    import paddle
    import paddle.nn as nn
    
    class LeNet(nn.Layer):
        def __init__(self, num_classes=10):
            super(LeNet, self).__init__()
            self.num_classes = num_classes
            self.features = nn.Sequential(
                nn.Conv2D(
                    1, 6, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2D(2, 2),
                nn.Conv2D(
                    6, 16, 5, stride=1, padding=0),
                nn.ReLU(),
                nn.MaxPool2D(2, 2))
    
            if num_classes > 0:
                self.fc = nn.Sequential(
                    nn.Linear(400, 120),
                    nn.Linear(120, 84),
                    nn.Linear(
                        84, 10))
    
        def forward(self, inputs):
            x = self.features(inputs)
    
            if self.num_classes > 0:
                x = paddle.flatten(x, 1)
                x = self.fc(x)
            return x
    
    lenet = LeNet()
    
    params_info = paddle.summary(lenet, (1, 1, 28, 28))
    print(params_info)
    # ---------------------------------------------------------------------------
    # Layer (type)       Input Shape          Output Shape         Param #
    # ===========================================================================
    # Conv2D-11      [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    #     ReLU-11       [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
    # MaxPool2D-11     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    # Conv2D-12      [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    #     ReLU-12      [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
    # MaxPool2D-12    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    # Linear-16         [[1, 400]]            [1, 120]           48,120
    # Linear-17         [[1, 120]]            [1, 84]            10,164
    # Linear-18         [[1, 84]]             [1, 10]              850
    # ===========================================================================
    # Total params: 61,610
    # Trainable params: 61,610
    # Non-trainable params: 0
    # ---------------------------------------------------------------------------
    # Input size (MB): 0.00
    # Forward/backward pass size (MB): 0.11
    # Params size (MB): 0.24
    # Estimated Total Size (MB): 0.35
    # ---------------------------------------------------------------------------
    # {'total_params': 61610, 'trainable_params': 61610}
        
    
  • 3.paddle.flops(net, input_size= None, custom_ops=None, print_detail=False) 打印網絡的基礎結構和參數信息

    和 paddle.summary() 差不多
    


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM