[Paper Reproduction] Semi-Supervised Learning and Data Augmentation: Core Logic

Posted: 2024-11-27 14:39:58

The core training logic is as follows:

import torch
import torch.nn.functional as F
from tqdm import tqdm

for epoch in range(epochs):
    model.train()
    # zip stops at the shorter of the two dataloaders
    train_tqdm = tqdm(zip(labeled_dataloader, unlabeled_dataloader))
    for labeled_batch, unlabeled_batch in train_tqdm:
        optimizer.zero_grad()
        # Supervised loss on the labeled batch
        data = labeled_batch[0].to(device)
        labels = labeled_batch[1].to(device)
        logits = model(normalize(strong_aug(data)))
        loss = F.cross_entropy(logits, labels)
        # Pseudo-label the unlabeled batch: predict on a weakly augmented
        # view and keep only predictions above the confidence threshold
        with torch.no_grad():
            data = unlabeled_batch[0].to(device)
            logits = model(normalize(weak_aug(data)))
            probs = F.softmax(logits, dim=-1)
            trusted = torch.max(probs, dim=-1).values > threshold
            pseudo_labels = torch.argmax(probs[trusted], dim=-1)
            # Scaling by the trusted fraction makes this term equivalent to
            # averaging the masked loss over the whole unlabeled batch
            loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
        # Unsupervised loss: the strongly augmented view must match the
        # pseudo-label; skip when nothing passed the threshold, since
        # cross_entropy on an empty batch would yield NaN
        if trusted.any():
            logits = model(normalize(strong_aug(data[trusted])))
            loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
        # Backpropagate and update the model parameters
        loss.backward()
        optimizer.step()

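The loop also assumes several names that the snippet does not define: the model, the optimizer, the hyperparameters epochs, threshold, and weight, and the two dataloaders. For reference only, below is a minimal, hypothetical setup sketch using common FixMatch-style defaults (confidence threshold 0.95, unlabeled loss weight 1.0); the actual choices used in this reproduction are in the attached files and may differ.

import torch
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder backbone; semi-supervised papers often use a Wide ResNet instead
model = torchvision.models.resnet18(num_classes=10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)

epochs = 100       # placeholder value
threshold = 0.95   # confidence threshold for keeping a pseudo-label
weight = 1.0       # weight of the unlabeled loss term

# labeled_dataloader / unlabeled_dataloader are ordinary DataLoader objects
# over the labeled and unlabeled splits; the unlabeled batch is typically
# several times larger than the labeled one

The augmentation helpers weak_aug, strong_aug, and normalize are likewise left undefined. A minimal sketch with torchvision, assuming uint8 CIFAR-10 image batches of shape (N, C, H, W), might look like this:

from torchvision import transforms

# Weak augmentation: flip-and-shift
weak_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
])

# Strong augmentation: RandAugment (expects uint8 image tensors)
strong_aug = transforms.RandAugment(num_ops=2, magnitude=9)

# Per-channel normalization with CIFAR-10 statistics
def normalize(x):
    mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device).view(1, 3, 1, 1)
    std = torch.tensor([0.2471, 0.2435, 0.2616], device=x.device).view(1, 3, 1, 1)
    return (x.float() / 255.0 - mean) / std

Note that applying these transforms to a whole batch draws one set of random parameters for the entire batch; per-sample augmentation inside the Dataset is the more common arrangement, in which case the augmentation calls would move out of the training loop.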
The code above is for illustration only; for the complete implementation, please refer to the attached files.