来源:/pdf/1912.
官方代码:/uizard-technologies/realmix
主要贡献:
1.在cifar10数据集上仅仅只利用每类250个标签数据实现了sota(error rate:9.79%)
2.在标签数据和无标签数据完全 mismatch的情况下,依然能够surpass baseline performance。
3.论证了realmix 能够surpass迁移学习,并且迁移学习对半监督学习有一个很好的补偿作用。
算法过程:(这篇文章相当于集成了uda+mixmatch中的一些trick)
GenerateTarget部分:
Augment(x)是为了保证训练一致性,里面包括了随机左右翻转和随机crop,就是对同一组数据同时做两次不同的增强(随机)。
Extend(x)是对无标签数据的扩充,文章中是对cifar10中当标签数据设定为每类数量为250的时候,无标签数据则扩充50倍,用的是cutout(效果最好),当然也可以用其它普通的增强技术。
该篇文章所用到的一些方法:
Mixup(该方法来自于MixMatch,对普通的mixup做了一个小的改动)
EM熵最小化+Sharpen function+EMA(指数滑动平均)(MixMatch 和UDA 都用了该操作,后面的sharpen作用主要减少对无标签数据错例的敏感性)
该博主做了很好的解释:/matrix_space/article/details/90732655
TSA(Training signal annealing出自于UDA,主要思想是从总损失中移除预测值大于设定阈值的样本损失,目的为了减轻少量标签数据过拟合造成的影响,主要有三种方式:log,exp,linear,针对的是标签数据)
详细算法移步:/daixiangzi/article/details/102989630
该篇文章还利用了 Out-of-distribution function 策略去减轻 distribution mismatch ,即标签数据和无标签数据分布不一致带来的影响:具体是针对无标签数据的,它是相当于把预测值低于 设定的超参阈值的样本损失丢弃掉,只仅仅计算预测值高于设定的超参阈值样本的对损失的贡献。下图解释:
注意:在这里解释一下什么是mismatch,举个例子:cifar10数据集中包含6个动物类和4个交通工具类。假设标签数据为6个动物类,那么0%mismatch就表示另外的4类无标签数据都为动物类,100%mismatcqh则表示另外4类无标签数据都为交工具类。
实验结果:
在cifar10和svhn数据上的实验结果
在迁移学习上的表现(可以发现realmix+迁移可以进步一步提升效果)
消融实验:
主要对两个因素做了对比实验一个是Extend(x)数据增强部分和Out Of Distribution MASK(x)部分
数据增强部分:
Simple Augs:a copy+水平翻转+随机crop
25 Augs:25 copys+水平翻转+随机crop
RealMix:50 copys+cutout
oodmask上的表现,这里仅仅只在mismatch为75%的时候,做一个简单的对比(下面还是显示有oodmask效果要更好一点)
最后,文章中也有指出oodmask超参不好调,并且也指出希望在未来把它作为一个SSL评估的一个重要标注,因为在现实中标签数据和无标签数据很多时候都是来自于不同分布的。附上原话: