pytorch中正向传播和反向传播的钩子Hook

时间:2022-11-27 14:51:05

最近在可视化特征图时遇到了钩子的问题,这里记录学习一下相关知识~


在某些情况下,我们需要对深度学习模型的前向计算和反向传播的行为做一定的修改。比如,我们想要观察深度学习模型的某一层中出现的异常值(NaN或者Inf),找出这些异常值的来源,或者想要对模块输出的张量做一定的修改。在这种情况下,可以通过在模块中引入钩子Hook来动态修改模块的行为。

一、钩子的概念

pytorch中的深度学习模型是由一个个子模块组合而成的,每个子模块有一定的输入(可以是张量和其他的python原生数据类型),并且会通过计算给出对应的输出。同样,在反向传播过程中,输入的是下一层梯度的输出,对应输出的是上一层的梯度和权重对应的梯度。钩子的引入就是为了能够在正向计算的前后和反向计算之后,对输入/输出的张量进行读取或者修改,以达到最终修改模块行为的目的。对于一个模块而言,可以有三种类型的钩子,分别在前向计算之前、前向计算之后和反向传播之后执行

二、模块执行之前的前向计算钩子

# 定义nn.Module的一个实例模块
module = ...
def hook(module, input):
    # 对模块权重或者输入进行操作的代码
    # 函数结果可以返回修改后的张量或者None
    return input
handle = module.register_forward_pre_hook(hook)

首先定义一个nn.Module的模块实例,然后定义一个钩子函数,这个函数有2个参数:模块本身和模块的输入(即forward方法的输入参数)。定义好这个钩子函数后,通过register_forward_per_hook来注册这个钩子函数,这样,在调用这个模块之前,会先进行钩子函数的调用,对模块和模块的输入做一定的修改,然后把修改后的参数传入模块中计算。注册钩子函数会返回一个句柄handle,通过调用这个句柄的remove方法可以移除这个钩子函数。

三、模块执行之后的前向计算钩子

# 定义nn.Module的一个实例模块
module = ...
def hook(module, input, output):
    # 对模块权重或者输入/输出进行操作的代码
    # 函数结果可以返回修改后的张量或者None
    return output
handle = module.register_forward_hook(hook)

该钩子函数有三个参数:module,input,output,分别是模块本身,模块的输入参数和输出参数。同样,这个钩子的注册函数会返回一个句柄,通过调用这个句柄的remove可以对这个钩子函数进行移除。

四、模块执行之后的反向传播钩子

# 定义nn.Module的一个实例模块
module = ...
def hook(module, grad_input, grad_output):
    # 对模块权重或者输入/输出梯度进行操作的代码
    # 函数结果可以返回修改后的张量或者None
    return output
handle = module.register_backward_hook(hook)

这个钩子的输入参数分别代表模块本身、输入的梯度和输出的梯度。注意,如果输入的参数有多个张量,输入梯度和输出梯度可能是一个元组,其中包含输入的每个张量的梯度。由于pytorch的模块可以输入张量和其他的python数据类型作为forward的参数,所以反向传播中grad_input和前向传播中的Input未必一一对应。

五、用例展示

最后用一个实例来展示一下如何使用钩子函数打印模块的输入和输出。首先定义模块前执行的钩子、模块后执行的钩子、以及梯度的钩子,然后分别在模块中注册这些钩子,接着执行前向计算。可以看到,模块在前向计算时分别调用了钩子函数,打印了输入张量的形状。然后进行反向传播计算,模块成功地打印了梯度的形状。当然,也可以把钩子函数换成更复杂的,比如检查模块输入/输出是否存在NaN, Inf这些值等,可以大大加快debug的速度。

import torch
import torch.nn as nn

def print_pre_shape(module, input):
    print('模块前钩子')
    print(module.weight.shape)
    print(input[0].shape)

def print_post_shape(module, input, output):
    print('模块后钩子')
    print(module.weight.shape)
    print(input[0].shape)
    print(output[0].shape)

def print_grad_shape(module, grad_input, grad_output):
    print('梯度钩子')
    print(module.weight.grad.shape)
    print(grad_input[0].shape)
    print(grad_output[0].shape)

conv = nn.Conv2d(16, 32, kernel_size=(3,3))
handle1 = conv.register_forward_pre_hook(print_pre_shape)
handle2 = conv.register_forward_hook(print_post_shape)
handle3 = conv.register_forward_backward_hook(print_grad_shape)

input = torch.randn(4, 16, 128, 128, requires_grad=True)

ret = conv(input)
ret = ret.sum()
ret.backward()

以上代码会输出下面的结果:

模块前钩子

torch.Size([32, 16, 3, 3])

torch.Size([4, 16, 128, 128])

模块后钩子

torch.Size([32, 16, 3, 3])

torch.Size([4, 16, 128, 128])

torch.Size([32, 126, 126])

梯度钩子

torch.Size([32, 16, 3, 3])

torch.Size([4, 16, 128, 128])

torch.Size([4, 32, 126, 126])