The core training logic is shown below:
import torch
import torch.nn.functional as F

for epoch in range(epochs):
    model.train()
    # zip stops at the shorter loader, so the unlabeled loader should
    # yield at least as many batches as the labeled one
    for labeled_batch, unlabeled_batch in zip(labeled_dataloader, unlabeled_dataloader):
        optimizer.zero_grad()
        # Supervised loss on the labeled samples
        data = labeled_batch[0].to(device)
        labels = labeled_batch[1].to(device)
        logits = model(normalize(strong_aug(data)))
        loss = F.cross_entropy(logits, labels)
        # Generate pseudo-labels from weakly augmented unlabeled samples
        with torch.no_grad():
            data = unlabeled_batch[0].to(device)
            logits = model(normalize(weak_aug(data)))
            probs = F.softmax(logits, dim=-1)
            # Trust only predictions whose confidence exceeds the threshold
            trusted = torch.max(probs, dim=-1).values > threshold
            pseudo_labels = torch.argmax(probs[trusted], dim=-1)
            # Scale the unsupervised loss by the fraction of trusted samples
            loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
        # Unsupervised loss: strongly augmented views against the pseudo-labels
        # (skipped when nothing clears the threshold, since cross_entropy
        # on an empty batch would return NaN)
        if trusted.any():
            logits = model(normalize(strong_aug(data[trusted])))
            loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
        # Backpropagate and update the model parameters
        loss.backward()
        optimizer.step()
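Note that loss_factor scales the pseudo-label loss by the fraction of unlabeled samples that clear the confidence threshold: for example, with weight = 1.0 and 48 of 64 predictions above the threshold, the unsupervised term is weighted by 48/64 = 0.75, so it contributes less while the model is still unconfident.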
The code above is for illustration only; see the attached files for the complete implementation.
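The loop assumes weak_aug, strong_aug, and normalize helpers that are defined elsewhere. As a rough sketch of what they could look like (not the exact versions in the attached files), the following uses torchvision transforms and assumes 32×32 RGB inputs stored as uint8 tensors, with CIFAR-10 normalization statistics; all of these choices are illustrative assumptions:

import torch
from torchvision import transforms

# Weak augmentation: random crop with padding plus horizontal flip.
# Applied to a whole batch, these transforms draw one set of random
# parameters per call; per-sample randomness would require applying
# them image by image (e.g. inside the Dataset).
weak_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
])

# Strong augmentation: the weak transforms followed by RandAugment
# (which expects uint8 tensors or PIL images)
strong_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=10),
])

# Convert uint8 [0, 255] images to float and standardize with
# (assumed) CIFAR-10 channel statistics
def normalize(x: torch.Tensor) -> torch.Tensor:
    x = x.float() / 255.0
    return transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616),
    )(x)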