Scikit-Learn 源码研读 (第二期)基类的实现细节

时间:2021-10-14 07:47:58

sklearn项目可以看成一棵大树,各种estimator是果实,而支撑这些估计器的主干,是为数不多的几个基类。常见的几个类有BaseEstimator、BaseSGD、ClassifierMixin、RegressorMixin,等等。

官方文档的API参考页面列出了主要的API接口,我们看下Base类

Scikit-Learn 源码研读 (第二期)基类的实现细节

本期我们只研究BaseEstimator、ClassifierMixin、RegressorMixin、TransformerMixin。BaseSGD是一个比较大的话题,需要单独开一期来仔细研究。

BaseEstimator

最底层的就是BaseEstimator类。主要暴露两个方法:set_paramsget_params.

get_params

这个方法旨在获取对象的参数,返回对象默认是{参数:参数值}的键值对。如果将get_params的参数deep设置为True,还会返回(如果有的话)子对象(它们是估计器)。下面我们来仔细看一下这个方法的实现细节:

Scikit-Learn 源码研读 (第二期)基类的实现细节

为了节约篇幅,我会将不重要的注释略去,以后都是这样处理,不再赘述,除非特殊说明。

(1)

函数体中主要就是getattr方法,语法:getattr(对象,要检索的属性[,如果属性不存在则返回的值])。Line200~208的任务是判断self(一般就是估计器的实例)是否含有key这个参数,如果有就返回它的参数值,否则人为设置为None。

为什么要写这么复杂呢? 其实可以直接写作 value = getattr(self, key, None),有点迷~

(2)

再来看Line209~212,如果用户设置了deep=True,并且value对象实现了get_params(说明value对象是一个子对象,即估计器,否则普通的参数是不会再次实现get_params方法的),则提取参数字典的键值对,并且写入字典。整个函数最后返回的也是字典。

(3)

我们先快速的看一下这个方法具体是怎么使用的,然后再继续追踪源码的实现。

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(random_state=0)
X = [[ 1, 2, 3], # 2 samples, 3 features
[11, 12, 13]]
y = [0, 1] # classes of each sample
clf.fit(X, y)

简单的实例化一个随机森林分类器的对象,我们看下对它调用get_params会返回什么:

clf.get_params()

{'bootstrap': True,
'class_weight': None,
'criterion': 'gini',
'max_depth': None,
'max_features': 'auto',
'max_leaf_nodes': None,
'min_impurity_decrease': 0.0,
'min_impurity_split': None,
'min_samples_leaf': 1,
'min_samples_split': 2,
'min_weight_fraction_leaf': 0.0,
'n_estimators': 10,
'n_jobs': None,
'oob_score': False,
'random_state': 0,
'verbose': 0,
'warm_start': False}

很明显,这就是这个随机森林分类器的默认参数方案。

(4)

我们注意到Line199这行,使用了另一个方法 for key in self._get_param_names():,现在研究该函数

Scikit-Learn 源码研读 (第二期)基类的实现细节

这里赘述一下,在sklearn这种大型的Python项目中,很多暴露出去的方法,其实质只是一个壳子,你可以理解为它是在搬运别人做的东西,只是美化包装一下交给调用者。例如get_params方法,它并没有真的获取到估计器实例的参数,因为_get_param_names在帮它干这个活儿。

@classmethod这个装饰器直接告诉我们,该方法的适用对象是类自身,而非实例对象。

这个函数有很多检查事项,真正获取参数的是 inspect.signature(init).parameters.values(),最后获取列表中每个对象的name属性。

set_params

这个方法作用是设置参数。正常来说,我们在初始化估计器的时候定制化参数,但是也有临时修改参数的需求,这时可以手工调用set_params方法。但是更多的还是由继承BaseEstimator的类来调用这个方法。

具体地,我们看下实现细节:

Scikit-Learn 源码研读 (第二期)基类的实现细节

这个方案支持处理嵌套字典,但是我们不去纠缠这么琐碎,直接看到L251,setattr(self, key, value),对估计器的key属性设置一个新的值。

应用的实例:

Scikit-Learn 源码研读 (第二期)基类的实现细节

ClassifierMixin

Mixin表示混入类,可以简单地理解为给其他的类增加一些额外的方法。Sklearn的分类、回归混入类只实现了score方法,任何继承它们的类需要自己去实现fitpredict等其他方法。

Scikit-Learn 源码研读 (第二期)基类的实现细节

关于混入类,简单的说就是一个父类,但是和普通的类有点不同,它需要指明元对象,_estimator_type。这里不再展开论述,感兴趣的读者请阅读这篇讨论 What is a mixin, and why are they useful?

可以看到,这个混入类的实现非常简单,求预测值和真实值的准确率,返回值是一个浮点数。注意预测值来自self.predict(),所以继承混入类的类必须自己实现predict方法,否则引发错误。后面不再重复强调该细节。

再次的,分类任务的混入类又是在搬运其它函数的劳动成果,那我们就来研究一下accuracy_score的实现细节

Scikit-Learn 源码研读 (第二期)基类的实现细节

为简洁起见,我们先忽略L185~189之间的代码,后面会有专门研究分类任务的度量方法的文章,在那里我们再仔细研究它。直接看L191,y_ture == y_pred,这是一个简单的写法,精妙在于避免了for循环,快速的检查两个对象之间每一个元素是否相等并且返回True/False。L193对score结果做一层包装。

  • L116:如果设置了normalize参数为True,则对score列表取平均值,就是预测正确的样本个数/总体个数=预测准确率
  • L118:如果有权重,则按照权重对各个样本的得分进行加权,作为最终的预测准确率
  • L121:如果没有上述两种设置,则直接返回预测正确的样本的个数。注意:sklearn默认的score方法返回预测准确率,而非预测正确的样本个数。

RegressorMixin

Scikit-Learn 源码研读 (第二期)基类的实现细节

毫不意外地,回归任务的混入类只实现了score方法,核心数学原理是 \(R^2\) 值。公式是 1-((y_true - y_pred)2)/((y_true - y_true_mean)2),直观上看,这个值是衡量预测值与真实值的偏离度与真实值自身偏离度的一个比值。 \(R^2\)最大为1,表示预测完全准确,值为0时表示模型没有任何预测能力。

Scikit-Learn 源码研读 (第二期)基类的实现细节

score方法调用了metrics模块的r2_score方法,返回值是浮点数。我们来研究下r2_score,这个函数是目前为止我们看过的最复杂的一个。因此,我们一块一块来研究。

检查传入的对象

Scikit-Learn 源码研读 (第二期)基类的实现细节

(1)检查传入对象的长度

L577调用check_consistent_length检查输入标签、输出标签、权重是不是有相同的长度。检查的方法也很简单,对每个对象计算长度,然后取不同的长度值有多少个,如果超过1个,说明几个对象之间的长度不一,则引发一个错误来警告。

Scikit-Learn 源码研读 (第二期)基类的实现细节

(2)检查传入的参数是否合法

L575调用_check_reg_targets方法,旨在检查传入参数是否合法。

Scikit-Learn 源码研读 (第二期)基类的实现细节

这个函数略长,但是大致做了以下几件事:

  • L83~95都是在做检查和格式转换。
  • L97~114检查输入multioutputy_true是否吻合,即真实的标签数组的维度如果是1的话,显然设置multioutput这个参数非None是不合法的。并且当真实标签数组的维度大于1的时候,若其维度和multioutput不同时也会引发错误以告警。
  • L115根据y_true的维度决定标签是哪种类型,分为:连续型和多类输出的连续型。

    注意:multioutput可以是字符串,也可以是一个数组,还可以是None值(考虑到向下兼容),因此这个参数非常灵活。后面研究具体算法时遇到了会再次提及,此处不作过多纠缠。

检查样本数和权重系数

继续看r2_score的实现:

Scikit-Learn 源码研读 (第二期)基类的实现细节

(3)L597~582检查预测值的样本数

如果预测值的样本数不足2个,则引发错误告警。因为决定系数(即\(R^2\))要求至少要有2个样本

(4)L584~588处理权重系数

  • L585调用np.ravel(),把权重数组拉平到一维
  • L586对sample_weights扩维,将一维扩充为二维,二维扩充为三维,以此类推。值得注意的是,np.newaxis放置的位置不同,扩充的方向是不同的,具体看下面这个小例子:

    Scikit-Learn 源码研读 (第二期)基类的实现细节
  • L588,如果没有传入权重系数,则默认设置为1

实现\(R^2\)的计算细节

(5)构造分子和分母

Scikit-Learn 源码研读 (第二期)基类的实现细节

(6)计算每个样本的得分

Scikit-Learn 源码研读 (第二期)基类的实现细节

  • L595~596 记录分母和分子的数组中不为0的索引值(就是非0值所在的位置)
  • L597 记录分子、分母同时不为0的样本的索引值。如果对这个写法不熟悉,这里有个小例子帮助理解:

    Scikit-Learn 源码研读 (第二期)基类的实现细节
  • L598~599 创建一个和真实标签相同长度的全1数组,然后对合法的索引位置计算真实的\(R^2\)值。
  • L603 将分母为0的索引位置的值设置为0,这里设为其他常数也是可以的,对于同一个回归任务的评价没有影响。

(7)根据multioutput参数来决定各样本所得分数的权重

Scikit-Learn 源码研读 (第二期)基类的实现细节

  • L605~607 如果指明raw_values,则输出每个样本的分数
  • L608~610 如果指明uniform_average,则avg_weights设置为None,其实就是均匀分布权重
  • L611~612 如果指明variance_weighted,则直接用分母作权重
  • L614~618 处理常量y值或一维数组的情形。如果分母全是0,则:若分子有非0,直接返回1;否则返回0
  • L620 如果multioutput不是字符串,则直接把它作为最后的权重系数

(8)返回得分

return np.average(output_scores, weights=avg_weights)

刚刚说到,指明uniform_average,则avg_weights设置为None。在numpy.average这个方法里,如果权重是None,计算均值就是简单的mean()函数。

TransformerMixin

Scikit-Learn 源码研读 (第二期)基类的实现细节

这个混入类的实现比较简单,完全依靠使用它的类自己实现的fit方法和transform方法。但是它会根据是否有标签,决定是有监督任务还是无监督任务。等后面遇到再具体讨论。

补充

Scikit-Learn 源码研读 (第二期)基类的实现细节

我们在研究分类混入类和回归混入类的时候,都发现有_estimator_type这个变量,它的具体作用就是这里看到的,判断一个估计器是用于分类任务还是回归任务的。


如果有任何纰漏差错,欢迎评论互动。

Scikit-Learn 源码研读 (第二期)基类的实现细节

Scikit-Learn 源码研读 (第二期)基类的实现细节的更多相关文章

  1. 搭建Spark源码研读和代码调试的开发环境

    转载自https://github.com/linbojin/spark-notes/blob/master/ide-setup.md 搭建Spark源码研读和代码调试的开发环境 工欲善其事,必先利其 ...

  2. 21 BasicTaskScheduler基本任务调度器(一)——Live555源码阅读(一)任务调度相关类

    21_BasicTaskScheduler基本任务调度器(一)——Live555源码阅读(一)任务调度相关类 BasicTaskScheduler基本任务调度器 BasicTaskScheduler基 ...

  3. JDK1.8源码阅读笔记(1)Object类

    JDK1.8源码阅读笔记(1)Object类 ​ Object 类属于 java.lang 包,此包下的所有类在使⽤时⽆需⼿动导⼊,系统会在程序编译期间⾃动 导⼊.Object 类是所有类的基类,当⼀ ...

  4. 12 哈希表相关类——Live555源码阅读(一)基本组件类

    12 哈希表相关类--Live555源码阅读(一)基本组件类 这是Live555源码阅读的第一部分,包括了时间类,延时队列类,处理程序描述类,哈希表类这四个大类. 本文由乌合之众 lym瞎编,欢迎转载 ...

  5. Mybatis源码解析(三) —— Mapper代理类的生成

    Mybatis源码解析(三) -- Mapper代理类的生成   在本系列第一篇文章已经讲述过在Mybatis-Spring项目中,是通过 MapperFactoryBean 的 getObject( ...

  6. JDK1.8源码(四)——java.util.Arrays类

    一.概述 1.介绍 Arrays 类是 JDK1.2 提供的一个工具类,提供处理数组的各种方法,基本上都是静态方法,能直接通过类名Arrays调用. 二.类源码 1.asList()方法 将一个泛型数 ...

  7. 19 BasicTaskScheduler0 基本任务调度类基类(一)——Live555源码阅读(一)任务调度相关类

    这是Live555源码阅读的第二部分,包括了任务调度相关的三个类.任务调度是Live555源码中很重要的部分. 本文由乌合之众 lym瞎编,欢迎转载 http://www.cnblogs.com/ol ...

  8. 18 TaskScheduler任务调度器抽象基类——Live555源码阅读(一)任务调度相关类

    这是Live555源码阅读的第二部分,包括了任务调度相关的三个类.任务调度是Live555源码中很重要的部分. 本文由乌合之众 lym瞎编,欢迎转载 http://www.cnblogs.com/ol ...

  9. 20 BasicTaskScheduler0 基本任务调度类基类(二)——Live555源码阅读(一)任务调度相关类

    这是Live555源码阅读的第二部分,包括了任务调度相关的三个类.任务调度是Live555源码中很重要的部分. 本文由乌合之众 lym瞎编,欢迎转载 http://www.cnblogs.com/ol ...

随机推荐

  1. Dell xps 13 9350待机时总是关机的处理方法

    现象: 年初买的dell xps13 9350,最近可能是由于win10做了一些更新,每次睡眠就自动关机了,重启很多次,修改电源选项都没用, 原因分析: 在网上搜了一下,有人发现xps15 9350也 ...

  2. LEETCODE —— Populating Next Right Pointers in Each Node

    Populating Next Right Pointers in Each Node Given a binary tree struct TreeLinkNode { TreeLinkNode * ...

  3. kafka生产消息的速度跟什么有关?

    kafka的吞吐量很大,在保证带宽的情况下,网上的一些测试表明3台broker,没有replication,6个partition的情况下,一般的写入速度可以达到300MB/s.参考:kakfa测试 ...

  4. IIS 发布程序的一些心得

    1.应用程序池一般自己建立对应Framework版本的程序池,并托管管道模式为经典 2.在IIS根目录双击,右侧的“ISAPI和CGI限制” 双击打开,将自己所需要的Framework版本的限制设置为 ...

  5. PAT 1019

    1019. General Palindromic Number (20) A number that will be the same when it is written forwards or ...

  6. 洛谷P1118 数字三角形游戏

    洛谷1118 数字三角形游戏 题目描述 有这么一个游戏: 写出一个1-N的排列a[i],然后每次将相邻两个数相加,构成新的序列,再对新序列进行这样的操作,显然每次构成的序列都比上一次的序列长度少1,直 ...

  7. IOS — 关于Socket传输文件需要自定义延时或者包大小的情况

    1. 首先导入头文件 #include <stdio.h> #include <errno.h> #include <string.h> #include < ...

  8. VS2010调用VLFeat

    相比OpenCV,VLFeat的代码全是开源,并且非常重要的一点,事实上现的sift和Low的精度差点儿相同,这个团队全是码神,膜拜一下. 依照以下的网址进行安装,本人已经装上了,确实能够的. 安装參 ...

  9. 虚拟主机VPS区别:DigitalOcean优惠码

    虚拟主机与VPS区别在哪?这里从几个角度解释下: 虚拟化技术 把一*立服务器用虚拟化技术“切割”开,分配不同的硬件配置,然后分配给不同的普通用户使用,就是虚拟主机的基本思路.部分奸商为了节省成本,甚 ...

  10. SourceTree安装跳过注册

    问题描述 SourceTree是一个可视化管理版本文件的工具,但是注册的时候需要打开外国的网站. 因此要安装时需要绕过这一步才能使用. 解决方案 1. 在下面路径下创建一个accounts.json文 ...