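Keras is an open-source artificial neural network library written in Python. It has a neat API, model.summary(), that shows what your model looks like, which is very helpful when debugging a network. torchsummary is a bare-bones attempt to mimic the same thing in PyTorch through a summary() function. Its aim is to provide information complementary to what print(your_model) shows in PyTorch: each layer's output shape and parameter count, plus totals and memory estimates.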
Author: sksq96
Repository: https://github.com/sksq96/pytorch-summary
Installation:
pip install torchsummary
or
git clone https://github.com/sksq96/pytorch-summary
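Usage:
```python
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
```
Note that input_size is required to make a forward pass through the network. It is the shape of a single input without the batch dimension; the batch dimension appears as -1 in the Output Shape column.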
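Why a forward pass is needed becomes clearer with a stripped-down version of the mechanism. The sketch below is a hypothetical re-implementation, not torchsummary's actual code: the name `tiny_summary` is made up, it reports only leaf modules (torchsummary's own selection rules may differ), and it assumes a CPU model with a single tensor input and output. It attaches a forward hook to each leaf module with PyTorch's `register_forward_hook`, pushes one random batch of shape `(batch_size, *input_size)` through the model, and records each layer's output shape (batch dimension replaced by -1) and parameter count.

```python
import torch
import torch.nn as nn


def tiny_summary(model: nn.Module, input_size, batch_size: int = 2):
    """Hypothetical, simplified version of the idea behind summary().

    Hooks every leaf module, runs one dummy forward pass on CPU, and returns
    (name, output shape, parameter count) for each layer call.
    """
    rows, handles = [], []

    def make_hook(kind):
        def hook(module, inputs, output):
            # Show the batch dimension as -1, like the tables in this post.
            shape = [-1] + list(output.shape[1:])
            n_params = sum(p.numel() for p in module.parameters())
            rows.append((f"{kind}-{len(rows) + 1}", shape, n_params))
        return hook

    for module in model.modules():
        if not list(module.children()):  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(type(module).__name__)))

    # input_size excludes the batch dimension; without this dummy batch the
    # output shapes could not be observed at all.
    x = torch.rand(batch_size, *input_size)
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # leave no hooks behind on the model
    return rows


# Usage: a small convolutional model.
net = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(10 * 24 * 24, 10),
)
for row in tiny_summary(net, (1, 28, 28)):
    print(row)
```

The first row printed is `('Conv2d-1', [-1, 10, 24, 24], 260)`, which matches the first row of the MNIST table. Only modules that actually run during the dummy pass produce rows, which is why the input shape must be known in advance.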
CNN for MNIST
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)    # 1x28x28 -> 10x24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)   # 10x12x12 -> 20x8x8
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                   # 320 = 20 * 4 * 4
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten the 20x4x4 feature maps
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # torch.device needs PyTorch >= 0.4.0
model = Net().to(device)

summary(model, (1, 28, 28))
```
Output:
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
----------------------------------------------------------------
```
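Every number in this table can be checked by hand. Only the five nn.Module layers get rows: F.relu, F.max_pool2d and F.dropout are functional calls, not modules, so no per-layer entry is produced for them. The plain-Python sketch below (no PyTorch needed) recomputes the parameter counts and the size estimates. The size formulas are assumptions inferred from the numbers, not taken from torchsummary's documentation: 4 bytes per float32 value, 1 MB = 1024² bytes, and a forward/backward estimate of twice the summed element count of all reported layer outputs.

```python
# Hand check of the MNIST summary table (assumed size formulas, see above).
BYTES_PER_FLOAT = 4
MB = 1024 ** 2


def conv2d_params(in_ch, out_ch, k):
    # One k x k kernel per (input, output) channel pair plus one bias per output channel.
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f, out_f):
    # Weight matrix plus one bias per output feature.
    return out_f * in_f + out_f


def conv_out(size, k, stride=1, padding=0):
    # Standard output-size formula for a square convolution.
    return (size + 2 * padding - k) // stride + 1


params = {
    "Conv2d-1": conv2d_params(1, 10, 5),
    "Conv2d-2": conv2d_params(10, 20, 5),
    "Dropout2d-3": 0,
    "Linear-4": linear_params(320, 50),
    "Linear-5": linear_params(50, 10),
}

# Spatial sizes: 28 -> conv 24 -> pool 12 -> conv 8 -> pool 4.
s1 = conv_out(28, 5)
s2 = conv_out(s1 // 2, 5)
# Element counts of the five reported layer outputs (batch dimension excluded).
layer_outputs = [10 * s1 * s1, 20 * s2 * s2, 20 * s2 * s2, 50, 10]

total_params = sum(params.values())
input_mb = 1 * 28 * 28 * BYTES_PER_FLOAT / MB
fwd_bwd_mb = 2 * sum(layer_outputs) * BYTES_PER_FLOAT / MB
params_mb = total_params * BYTES_PER_FLOAT / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(params)
print(f"Total params: {total_params:,}")
print(f"Input {input_mb:.2f}, fwd/bwd {fwd_bwd_mb:.2f}, params {params_mb:.2f}, total {total_mb:.2f}")
```

Under these assumptions the script reproduces the table exactly: 21,840 parameters and 0.00 / 0.06 / 0.08 / 0.15 MB.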
VGG16
```python
import torch
from torchvision import models
from torchsummary import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg16().to(device)

summary(vgg, (3, 224, 224))
```
Output:
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
```
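The same arithmetic scales to VGG16. The sketch below rebuilds the 38 rows of the table from a configuration list and recomputes the totals. `CFG` and `vgg16_stats` are hypothetical names, not torchvision's API. The layer layout is read off the table above: 3×3 convolutions with padding 1, each followed by ReLU; a 2×2 max pool at every "M"; and a 25088 → 4096 → 4096 → 1000 classifier with ReLU and Dropout after the first two Linear layers. The size formulas carry the same assumptions as the MNIST check (float32, 1 MB = 1024² bytes, forward/backward = 2 × summed layer outputs).

```python
# Hypothetical hand check of the VGG16 summary totals.
CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
       512, 512, 512, "M", 512, 512, 512, "M"]
BYTES_PER_FLOAT = 4
MB = 1024 ** 2


def vgg16_stats(cfg=CFG, in_ch=3, size=224, num_classes=1000):
    input_elems = in_ch * size * size
    params = 0
    outputs = []  # element count of every row in the summary table
    ch = in_ch
    for v in cfg:
        if v == "M":
            size //= 2
            outputs.append(ch * size * size)        # MaxPool2d row
        else:
            params += v * ch * 3 * 3 + v            # Conv2d weights + biases
            outputs += [v * size * size] * 2        # Conv2d row and ReLU row
            ch = v
    # Classifier: Linear, ReLU, Dropout, Linear, ReLU, Dropout, Linear.
    dims = [ch * size * size, 4096, 4096, num_classes]
    for i, (a, b) in enumerate(zip(dims, dims[1:])):
        params += a * b + b                         # Linear weights + biases
        outputs.append(b)                           # Linear row
        if i < len(dims) - 2:
            outputs += [b, b]                       # ReLU and Dropout rows
    input_mb = input_elems * BYTES_PER_FLOAT / MB
    fwd_bwd_mb = 2 * sum(outputs) * BYTES_PER_FLOAT / MB
    params_mb = params * BYTES_PER_FLOAT / MB
    return {
        "rows": len(outputs),
        "params": params,
        "input_mb": input_mb,
        "fwd_bwd_mb": fwd_bwd_mb,
        "params_mb": params_mb,
        "total_mb": input_mb + fwd_bwd_mb + params_mb,
    }


stats = vgg16_stats()
print(f"Rows: {stats['rows']}, Total params: {stats['params']:,}")
for key in ("input_mb", "fwd_bwd_mb", "params_mb", "total_mb"):
    print(f"{key}: {stats[key]:.2f}")
```

It reports 38 rows and 138,357,544 parameters, and the four size figures round to 0.57, 218.59, 527.79 and 746.96 MB, matching the table. The three Linear layers alone hold about 123.6 million of those parameters, which is why Params size dominates the estimate.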