ptflops——用于分析 PyTorch 模型计算复杂度

时间:2025-02-19 09:05:23

1. ptflops使用

ptflops 是一个用于分析 PyTorch 模型计算复杂度的工具包,它可以帮助开发者快速了解模型的 FLOPs (Floating Point Operations) 和参数量,从而进行模型优化和选择。

1.1. 安装

首先,需要安装 ptflops。可以使用 pip 进行安装:

pip install ptflops

1.2. 基本用法

ptflops 的基本用法,示例一,使用torchvision模型:

from ptflops import get_model_complexity_info
import torchvision.models as models

# 创建一个模型实例
model = models.resnet18()

# 获取模型的 FLOPs 和参数量
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)

# 打印结果
print(f'Computational complexity: {macs}, Number of parameters: {params}')

这段代码首先导入 get_model_complexity_info 函数,然后创建一个 ResNet-18 模型实例。接着,调用 get_model_complexity_info 函数,传入模型实例和输入数据的形状,以及一些可选参数。as_strings=True 表示将 FLOPs 和参数量以字符串形式返回,print_per_layer_stat=True 表示打印每一层的 FLOPs 和参数量。最后,打印输出模型的 FLOPs 和参数量。

ptflops 的基本用法,示例二,使用timm模型:

import timm
from ptflops import get_model_complexity_info  # Flops counting tool for neural networks in pytorch framework

model = timm.create_model('resnet50', pretrained=True)

macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=False,verbose=False)
print(f'Computational complexity: {macs}, Number of parameters: {params}')

1.3. 高级用法

除了基本用法外,ptflops 还提供了一些高级功能,可以更灵活地分析模型的计算复杂度。

1.3.1. 自定义输入

可以通过 custom_input 参数来自定义输入数据。例如,如果模型需要多个输入,或者输入数据的形状与默认值不同,可以使用这个参数。

macs, params = get_model_complexity_info(model, [(3, 224, 224), (1, 128)], as_strings=True, print_per_layer_stat=True, custom_input=[torch.randn(3, 224, 224), torch.randn(1, 128)])

1.3.2. 忽略特定层

可以通过 ignore_layers 参数来忽略特定层的计算复杂度。例如,如果想忽略模型中的某些层,可以将它们的名字传递给这个参数。

macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True, ignore_layers=['layer4'])

1.3.3. 指定算子

可以通过 operators 参数来指定要计算的算子类型。默认情况下,ptflops 会计算所有算子的 FLOPs。如果只想计算某些特定算子的 FLOPs,可以将它们的类型传递给这个参数。

macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True, operators=['Conv2d', 'Linear'])

1.3.4. 使用不同的 backend

ptflops 支持不同的 backend 来计算 FLOPs。可以通过 backend 参数来指定要使用的 backend。目前支持的 backend 有 'pytorch''aten'

macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True, backend='aten')

1.4. 总结

ptflops 是一款功能强大的 PyTorch 模型计算复杂度分析工具,可以帮助开发者快速了解模型的 FLOPs 和参数量,从而进行模型优化和选择。除了基本用法外,ptflops 还提供了一些高级功能,可以更灵活地分析模型的计算复杂度。

2. 关于get_model_complexity_info 返回值macsparams 的说明

2.1. macsparams 的含义及计算方法

  • macs (Multiply-Accumulate Operations): 指的是模型中乘法和加法操作的总次数。在深度学习中,乘加操作(Multiply-Accumulate)是最常见的运算,例如卷积、线性变换等。一个 MACs 操作包含一个乘法和一个加法。

    • 计算方法: ptflops 通过分析模型的结构和每一层的运算,统计出模型中所有乘加操作的次数。具体来说,它会遍历模型的每一层,根据该层的运算类型(如卷积、线性变换等)和输入输出的形状,计算出该层所需的乘加操作次数,然后将所有层的乘加操作次数累加起来,得到总的 MACs。
  • params (Parameters): 指的是模型中需要训练的参数的总数量。参数是模型中可学习的部分,例如卷积核、权重矩阵等。

    • 计算方法: ptflops 通过分析模型的结构,统计出模型中所有需要训练的参数的数量。具体来说,它会遍历模型的每一层,根据该层的参数类型和形状,计算出该层参数的数量,然后将所有层的参数数量累加起来,得到总的参数量。

2.2. 为什么 macsparams 能表达模型的复杂度

  • macs: macs 反映了模型的计算量大小。macs 越大,表示模型需要更多的计算资源和时间来完成推理过程。因此,macs 可以用来衡量模型的计算复杂度。

  • params: params 反映了模型的存储空间大小。params 越大,表示模型需要更多的存储空间来保存模型参数。此外,params 也在一定程度上影响模型的训练难度和过拟合风险。因此,params 可以用来衡量模型的模型复杂度。

通常来说,macsparams 越大,模型的复杂度就越高。但是,模型的复杂度并不完全由 macsparams 决定,还受到其他因素的影响,例如模型的结构、激活函数等。

2.3. macs 表示 multiply-add operations 吗?那除法运算、减法运算、指数运算等不考虑了吗?

是的,macs 主要表示 multiply-add operations。虽然除法、减法、指数运算等也属于模型的计算量,但它们通常在深度学习模型中占比较小,因此在计算模型复杂度时,通常只考虑乘加运算。

ptflops 在计算 macs 时,主要考虑以下几种运算:

  • 卷积运算: 卷积运算是深度学习中最重要的运算之一,它包含了大量的乘加操作。
  • 线性变换: 线性变换(如全连接层)也包含了大量的乘加操作。
  • 激活函数: 激活函数(如 ReLU、Sigmoid 等)通常只包含少量的加法、乘法和指数运算,因此在计算 macs 时通常忽略不计。
  • 其他运算: 其他运算(如除法、减法、指数运算等)在深度学习模型中占比较小,因此在计算 macs 时也通常忽略不计。

需要注意的是,ptflops 只是一个近似的计算工具,它可能无法精确计算出模型的所有计算量。但是,对于大多数深度学习模型来说,ptflops 的计算结果已经足够用来衡量模型的复杂度了。