paddle08-高层API(paddle.Model/summary/flops)


🌟 高层API 相关 3个

  • 1.class paddle.Model(net, inputs, labels) 一个具备训练、测试、推理的神经网络

    该对象同时支持静态图和动态图模式,飞桨框架默认为动态图模式,  
    通过 paddle.enable_static() 来切换到静态图模式。需要注意的是,需要在实例化 Model 对象之前完成切换。
    
    * network (paddle.nn.Layer): 是 paddle.nn.Layer 的一个实例
    * inputs: 
    * labels: 
    
    
    • 方法:train_batch(inputs, labels= None)
    • 方法: eval_batch(inputs, labels= None)
    • 方法: predict_batch(inputs)
    • 方法: save(path, training= True)
    • 方法: load(path, skip_mismatch=False, reset_optimizer= False)
    • 方法: parameters(*args, **kwargs)
    • 方法: prepare(optimizer=None, loss=None, metrics= None)
    • 方法: fit(train_data= None, eval_data= None, batch_size=1, epochs= 1, eval_freq= 1, log_freq=10, save_dir= None
      save_freq= 1, verbose= 2, drop_last= False, shuffle= True, num_workers= 0, callbacks= None)
    • 方法: evaluate(eval_data, batch_size=1, log_freq= 10, verbose= 2, num_workers=0, callbacks= None)
    • 方法: predict(test_data, batch_size=1, num_workers=0, stack_outputs= False, callbacks= None)
    • 方法: summary(input_size= None, batch_size= None, dtype= None)
  • 2.paddle.summary(net, input_size, dtypes=None) 打印网络的基础结构和参数信息

    返回:字典,包含了总的参数量和总的可训练的参数量。
    
    * net: Layer, 网络实例,必须是 Layer 的子类 
    * input_size: tuple/list/InputSpec   输入张量的大小. 
    如果网络只有一个输入,那么该值需要设定为tuple或InputSpec。  
    如果模型有多个输入。那么该值需要设定为list[tuple|InputSpec],包含每个输入的shape。  
    
    import paddle
    import paddle.nn as nn
    
    class LeNet(nn.Layer):
        def __init__(self, num_classes=10):
            super(LeNet, self).__init__()
            self.num_classes = num_classes
            self.features = nn.Sequential(
                nn.Conv2D(
                    1, 6, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2D(2, 2),
                nn.Conv2D(
                    6, 16, 5, stride=1, padding=0),
                nn.ReLU(),
                nn.MaxPool2D(2, 2))
    
            if num_classes > 0:
                self.fc = nn.Sequential(
                    nn.Linear(400, 120),
                    nn.Linear(120, 84),
                    nn.Linear(
                        84, 10))
    
        def forward(self, inputs):
            x = self.features(inputs)
    
            if self.num_classes > 0:
                x = paddle.flatten(x, 1)
                x = self.fc(x)
            return x
    
    lenet = LeNet()
    
    params_info = paddle.summary(lenet, (1, 1, 28, 28))
    print(params_info)
    # ---------------------------------------------------------------------------
    # Layer (type)       Input Shape          Output Shape         Param #
    # ===========================================================================
    # Conv2D-11      [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    #     ReLU-11       [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
    # MaxPool2D-11     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    # Conv2D-12      [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    #     ReLU-12      [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
    # MaxPool2D-12    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    # Linear-16         [[1, 400]]            [1, 120]           48,120
    # Linear-17         [[1, 120]]            [1, 84]            10,164
    # Linear-18         [[1, 84]]             [1, 10]              850
    # ===========================================================================
    # Total params: 61,610
    # Trainable params: 61,610
    # Non-trainable params: 0
    # ---------------------------------------------------------------------------
    # Input size (MB): 0.00
    # Forward/backward pass size (MB): 0.11
    # Params size (MB): 0.24
    # Estimated Total Size (MB): 0.35
    # ---------------------------------------------------------------------------
    # {'total_params': 61610, 'trainable_params': 61610}
        
    
  • 3.paddle.flops(net, input_size= None, custom_ops=None, print_detail=False) 打印网络的基础结构和参数信息

    和 paddle.summary() 差不多
    


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM