使用钩子函数的方式提取视觉特征

时间:2024-10-03 07:21:22

通过注册钩子函数,可以在模型的计算过程中插入需要执行的任意代码片段。在视觉特征提取过程中可以根据模型的结构,将正向钩子函数注册到指定的层中,然后通过读取该层的输入或输出数据,将视觉特征提取出来。

找到目标层,可以通过模型的源码找到指定的目标层,也可以通过print函数将模型对象输出并从中选取要注册钩子函数的目标层。

import torch
import torchvision
from PIL import Image
import torchvision.transforms as T

transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

img = Image.open('./test/dog/dog.55.jpg')
img_input = transforms(img)
img_input = img_input.unsqueeze(0)
img_input = img_input.to(device)

model = torchvision.models.resnet18()               # 没加载预训练模型
in_feat_num = model.fc.in_features
model.fc = torch.nn.Linear(in_features=in_feat_num, out_features=2)   # 在resnet18-f37072fd.pth的预训练模型基础上finetune的dog/cat模型
model.load_state_dict(torch.load('resnet18-f37072fd_finetune_allLayer.pth'))   # 加载finetune之后的权重
model.to(device)
# model.eval()   # 设置为eval模式

in_list = []     # 存放输入目标层的特征
out_list = []    # 存放输出目标层的特征
def hook(module_placeholder, input, output):
    print('in', len(input))   # 输入项是传入该层的参数,元组类型,输出是1,说明该层只有一个输入
    for val in input:         # 遍历每个输入项
        print(f"input val:{val.size()}")  # 输出的形状是torch.Size([1, 512, 7, 7])
    for i in range(input[0].size()[0]):  # 遍历多batch的每个图片的特征
        in_list.append(input[0][i].cpu().numpy())         # 保存单张图片的特征
        print(f'in, {input[0][i].cpu().numpy().shape}')   # 输出特征形状(512, 7, 7)
    print('out', len(output))   # 输出项是特征张量,值等于batchsize = 1
    for i in range(output.size(0)):
        out_list.append(output[i].cpu().numpy())    # 保存单张图片的特征
        print(f'out, {output[i].cpu().numpy().shape}')  # 输出特征的形状(512,1,1)

model.avgpool.register_forward_hook(hook)
with torch.no_grad():
    y_pred = model(img_input)

print('Done.')

输出:

in 1
input val:torch.Size([1, 512, 7, 7])
in, (512, 7, 7)
out 1
out, (512, 1, 1)
Done.

需要注意的是,钩子函数的输入项核输出项内容定义并不一致。输入项是一个元组,元组中的元素个数与该层的输入参数个数一致,每个元素才是真正的特征数据,而输出项直接就是该层处理后的特征数据。