查看模型流程、tensor的變化、參數量
example:
from torchinfo import summary
for X, y in train_dl:
print(summary(model, X.shape))
break
output:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
PLModel -- --
├─RNNModel: 1-1 [256, 5] --
│ └─ConvNormPool: 2-1 [256, 128, 93] --
│ │ └─Conv1d: 3-1 [256, 128, 183] 768
│ │ └─BatchNorm1d: 3-2 [256, 128, 183] 256
│ │ └─Swish: 3-3 [256, 128, 183] --
│ │ └─Conv1d: 3-4 [256, 128, 183] 82,048
│ │ └─BatchNorm1d: 3-5 [256, 128, 183] 256
│ │ └─Swish: 3-6 [256, 128, 183] --
│ │ └─Conv1d: 3-7 [256, 128, 183] 82,048
│ │ └─BatchNorm1d: 3-8 [256, 128, 183] 256
│ │ └─Swish: 3-9 [256, 128, 183] --
│ │ └─MaxPool1d: 3-10 [256, 128, 93] --
│ └─ConvNormPool: 2-2 [256, 128, 46] --
│ │ └─Conv1d: 3-11 [256, 128, 89] 82,048
│ │ └─BatchNorm1d: 3-12 [256, 128, 89] 256
│ │ └─Swish: 3-13 [256, 128, 89] --
│ │ └─Conv1d: 3-14 [256, 128, 89] 82,048
│ │ └─BatchNorm1d: 3-15 [256, 128, 89] 256
│ │ └─Swish: 3-16 [256, 128, 89] --
│ │ └─Conv1d: 3-17 [256, 128, 89] 82,048
│ │ └─BatchNorm1d: 3-18 [256, 128, 89] 256
│ │ └─Swish: 3-19 [256, 128, 89] --
│ │ └─MaxPool1d: 3-20 [256, 128, 46] --
│ └─RNN: 2-3 [256, 128, 256] --
│ │ └─LSTM: 3-21 [256, 128, 256] 180,224
│ └─AdaptiveAvgPool1d: 2-4 [256, 128, 1] --
│ └─Linear: 2-5 [256, 5] 645
==========================================================================================
Total params: 593,413
Trainable params: 593,413
Non-trainable params: 0
Total mult-adds (G): 19.24
==========================================================================================
Input size (MB): 0.19
Forward/backward pass size (MB): 494.94
Params size (MB): 2.37
Estimated Total Size (MB): 497.50
==========================================================================================