【DeepSeek背后的技术】系列二:大模型知识蒸馏(Knowledge Distillation)-4 实践

时间:2025-02-05 15:50:32

以下是一个简单的模型蒸馏代码示例,使用一个预训练的ResNet-18模型作为教师模型,并使用一个简单的CNN模型作为学生模型。同时,将使用交叉熵损失函数和L2正则化项来优化学生模型的性能表现。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms

# 定义教师模型和学生模型
teacher_model = models.resnet18(pretrained=True)
student_model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 7 * 7, 10)
)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9)
optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)

# 训练数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# 蒸馏过程
for epoch in range(10):
    running_loss_teacher = 0.0
    running_loss_student = 0.0
    
    for inputs, labels in trainloader:
        # 教师模型的前向传播
        outputs_teacher = teacher_model(inputs)
        loss_teacher = criterion(outputs_teacher, labels)
        running_loss_teacher += loss_teacher.item()
        
        # 学生模型的前向传播
        outputs_student = student_model(inputs)
        loss_student = criterion(outputs_student, labels) + 0.1 * torch.sum((outputs_teacher - outputs_student) ** 2)
        running_loss_student += loss_student.item()
        
        # 反向传播和参数更新
        optimizer_teacher.zero_grad()
        optimizer_student.zero_grad()
        loss_teacher.backward()
        optimizer_teacher.step()
        loss_student.backward()
        optimizer_student.step()
    
    print(f'Epoch {epoch+1}/10 \t Loss Teacher: {running_loss_teacher / len(trainloader)} \t Loss Student: {running_loss_student / len(trainloader)}')

在这个示例中:
(1)首先定义了教师模型和学生模型,并初始化了相应的损失函数和优化器;
(2)然后,加载了MNIST手写数字数据集,并对其进行了预处理;
(3)接下来,进入蒸馏过程:对于每个批次的数据,首先使用教师模型进行前向传播并计算损失函数值;然后使用学生模型进行前向传播并计算损失函数值(同时加入了L2正则化项以鼓励学生模型学习教师模型的输出);
(4)最后,对损失函数值进行反向传播和参数更新:打印了每个批次的损失函数值以及每个epoch的平均损失函数值。
通过多次迭代训练后,我们可以得到一个性能较好且轻量化的学生模型。