【pytorch】torchsummary打印模型结构和参数信息

时间:2025-03-31 08:08:12

目录

一. 安装

二. 调用

三. 参数设置

四. 使用示例


一. 安装

Anaconda prompt 中激活虚拟环境

输入pip install torchsummary进行安装

pip install torchsummary

二. 调用

from torchsummary import summary

可以成功调用就是安装成功 

三. 参数设置

(model, input_size, batch_size=-1, device="cuda")
  • model:pytorch 模型
  • input_size:模型输入size,形状为 C,H ,W
  • batch_size:batch_size,默认为 -1,在展示模型每层输出的形状时显示的 batch_size
  • device:“cuda"或者"cpu”,默认‘cuda’,如果用cpu,记得更改

四. 使用示例

from torchsummary import summary
import torch
from  import resnet18  # 以 resnet18 为例

input_shape = [512, 512]  # 设置输入大小
device = ('cuda' if .is_available() else 'cpu')  # 选择是否使用GPU
model = resnet18().to(device)   # 实例化网络
summary(model, (3, input_shape[0], input_shape[1]))

 输出结果:


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 256, 256]           9,408
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
         MaxPool2d-4         [-1, 64, 128, 128]               0
            Conv2d-5         [-1, 64, 128, 128]          36,864
       BatchNorm2d-6         [-1, 64, 128, 128]             128
              ReLU-7         [-1, 64, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          36,864
       BatchNorm2d-9         [-1, 64, 128, 128]             128
             ReLU-10         [-1, 64, 128, 128]               0
       BasicBlock-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,864
      BatchNorm2d-13         [-1, 64, 128, 128]             128
             ReLU-14         [-1, 64, 128, 128]               0
           Conv2d-15         [-1, 64, 128, 128]          36,864
      BatchNorm2d-16         [-1, 64, 128, 128]             128
             ReLU-17         [-1, 64, 128, 128]               0
       BasicBlock-18         [-1, 64, 128, 128]               0
           Conv2d-19          [-1, 128, 64, 64]          73,728
      BatchNorm2d-20          [-1, 128, 64, 64]             256
             ReLU-21          [-1, 128, 64, 64]               0
           Conv2d-22          [-1, 128, 64, 64]         147,456
      BatchNorm2d-23          [-1, 128, 64, 64]             256
           Conv2d-24          [-1, 128, 64, 64]           8,192
      BatchNorm2d-25          [-1, 128, 64, 64]             256
             ReLU-26          [-1, 128, 64, 64]               0
       BasicBlock-27          [-1, 128, 64, 64]               0
           Conv2d-28          [-1, 128, 64, 64]         147,456
      BatchNorm2d-29          [-1, 128, 64, 64]             256
             ReLU-30          [-1, 128, 64, 64]               0
           Conv2d-31          [-1, 128, 64, 64]         147,456
      BatchNorm2d-32          [-1, 128, 64, 64]             256
             ReLU-33          [-1, 128, 64, 64]               0
       BasicBlock-34          [-1, 128, 64, 64]               0
           Conv2d-35          [-1, 256, 32, 32]         294,912
      BatchNorm2d-36          [-1, 256, 32, 32]             512
             ReLU-37          [-1, 256, 32, 32]               0
           Conv2d-38          [-1, 256, 32, 32]         589,824
      BatchNorm2d-39          [-1, 256, 32, 32]             512
           Conv2d-40          [-1, 256, 32, 32]          32,768
      BatchNorm2d-41          [-1, 256, 32, 32]             512
             ReLU-42          [-1, 256, 32, 32]               0
       BasicBlock-43          [-1, 256, 32, 32]               0
           Conv2d-44          [-1, 256, 32, 32]         589,824
      BatchNorm2d-45          [-1, 256, 32, 32]             512
             ReLU-46          [-1, 256, 32, 32]               0
           Conv2d-47          [-1, 256, 32, 32]         589,824
      BatchNorm2d-48          [-1, 256, 32, 32]             512
             ReLU-49          [-1, 256, 32, 32]               0
       BasicBlock-50          [-1, 256, 32, 32]               0
           Conv2d-51          [-1, 512, 16, 16]       1,179,648
      BatchNorm2d-52          [-1, 512, 16, 16]           1,024
             ReLU-53          [-1, 512, 16, 16]               0
           Conv2d-54          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-55          [-1, 512, 16, 16]           1,024
           Conv2d-56          [-1, 512, 16, 16]         131,072
      BatchNorm2d-57          [-1, 512, 16, 16]           1,024
             ReLU-58          [-1, 512, 16, 16]               0
       BasicBlock-59          [-1, 512, 16, 16]               0
           Conv2d-60          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-61          [-1, 512, 16, 16]           1,024
             ReLU-62          [-1, 512, 16, 16]               0
           Conv2d-63          [-1, 512, 16, 16]       2,359,296
      BatchNorm2d-64          [-1, 512, 16, 16]           1,024
             ReLU-65          [-1, 512, 16, 16]               0
       BasicBlock-66          [-1, 512, 16, 16]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 328.01
Params size (MB): 44.59
Estimated Total Size (MB): 375.60
----------------------------------------------------------------

Process finished with exit code 0

如图,torchsummary 可以查看网络的顺序结构,显示每一层的类型、out shape和参数量; 还有网络参数量,网络模型大小; fp/bp 一次需要的内存大小等信息。