浅谈“知识蒸馏”技术在机器学习领域的应用

时间:2024-04-10 07:21:38

什么是知识蒸馏技术?

知识蒸馏技术首次出现是在Hinton几年前的一篇论文《Distilling the Knowledge in a Neural Network》。老大爷这么大岁数了还孜孜不倦的发明各种人工智能领域新名词,让我这种小白有很多可以去学习了解的内容,给个赞。

那什么是知识蒸馏技术呢?知识蒸馏技术的前提是将模型看作一个黑盒,数据进入后经过处理得到输出。通常意义上,复杂的模型的输出会比简单模型准确,那么是否有办法让复杂模型的知识传递给简单模型,就是知识蒸馏要探索的内容。

这就有点类似于迁移学习的原理,在迁移学习中,网络先学习大数量级的数据,然后生成base模型,再用小数据在base模型的基础上Fine-Tune。在知识蒸馏中,也是先生成复杂的Teacher模型,然后采用Teacher模型将知识传递给简单的Student模型的方式。

浅谈“知识蒸馏”技术在机器学习领域的应用

这样做的好处就是Student网络不必那么复杂,某种意义上实现了模型压缩的功能。

为什么叫蒸馏呢?

我最先好奇的其实不是Teacher网络和Student网络怎么传递知识,而是为什么用了Distilling这个词,我甚至觉得是不是某些人翻译错误了。于是有道了一下,蒸馏也可以是提炼的意思,我就懂了。

浅谈“知识蒸馏”技术在机器学习领域的应用

 

在化学领域有一个概念叫沸点,不同液体有不同的沸点,假设酒精和水混合在一起,我们想提取混合物中的水,就可以将温度加热到小于水的沸点而大于酒精沸点的温度,这样酒精就挥发了。知识蒸馏也是用相似的手段将需要的知识从Teacher网络蒸馏出来传递给Student网络。

Teacher网络和Student网络

具体怎么做呢?就是先构建一个非常复杂的网络作为Teacher网络,默认它的模型预测准确性很高。然后再构建一个简单的Student网络,用Teacher网络的输出结果q和Student网络的输出结果p做Cross Entropy(交叉熵),y是真实的目标值,最终算Loss的公式如下。

浅谈“知识蒸馏”技术在机器学习领域的应用

这样就达到了知识传递的问题,但是第二个问题来了。如果Teacher网络的预测准确率很准,比如Teacher网络是一个图片识别模型,识别猫、狗、兔子,Teacher网络很准的话,最后的输出可能是以下这样的概率分布结果,非常不均匀

  1. 猫的概率:0.998

  2. 兔子的概率:0.0013

  3. 狗的概率:0.0007

这种结果被称为Hard Label,因为真实传递下去的知识只有“猫”这一个结果,忽略了“兔子”比“狗”更像“猫”这样的知识,因为“兔子”和“狗”的权重太低。Teacher网络需要Soft Label,怎么做呢?在Softmax结果加入下面的公式:

浅谈“知识蒸馏”技术在机器学习领域的应用

其中zi是Softmax输出的logit,T取1,那么这个公式就是输出Hard Label。如果T的取值大于1,T越大整个Label的分布就变得越均匀,Hard Label就自然转变成了Soft Label。

总结

上一篇文章讲了模型压缩技术中的剪枝、量化、共享权重,加上今天这篇知识蒸馏就比较完整了。感觉知识蒸馏这种方案比较适合终端设备的模型压缩,特别是CV相关的模型。

非技术背景,纯YY,有不对的请大神指正。

参考文章:

(1)https://zhuanlan.zhihu.com/p/90049906

(2)https://blog.csdn.net/nature553863/article/details/80568658

(3)https://zhuanlan.zhihu.com/p/81467832

(衷心感谢以上文章的作者们)