Xgboost如何画出树?

时间:2024-03-24 08:11:49

对Xgboost使用了一定程度的读者,肯定会面临如何画出树这个问题,毕竟不画出树就是一个黑箱,黑箱总是难以让人放心。本篇博客完整地给出了如何画出Xgboost中的树的过程。

一、训练一个简单的Xgb模型

我们先训练一个Xgb模型。代码如下:

from sklearn.model_selection import train_test_split
from pandas import DataFrame

from xgboost.sklearn import XGBClassifier
from xgboost import plot_tree
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer


breast_cancer=load_breast_cancer()
X = breast_cancer.data
y = breast_cancer.target

X = DataFrame(X)
y = DataFrame(y)
# breast_cancer.feature_names的名字中带有空格,会报错。
X.columns = breast_cancer.feature_names
X.columns = ['l1','l2','l3','l4','l5','l6','l7','l8','l9','l10','l11','l12','l13','l14',
             'l15','l16','l17','l18','l19','l20','l21','l22','l23','l24','l25',
             'l26','l27','l28','l29','l30',]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)



clf = XGBClassifier(
    n_estimators=30,#三十棵树
    learning_rate =0.3,
    max_depth=3,
    min_child_weight=1,
    gamma=0.3,
    subsample=0.8,
    colsample_bytree=0.8,
    objective= 'binary:logistic',
    nthread=12,
    scale_pos_weight=1,
    reg_lambda=1,
    seed=27)



model_sklearn=clf.fit(X_train, y_train)
y_sklearn= clf.predict_proba(X_test)[:,1]

代码没有什么可说的,是一个最基本的调用。接下来,我们来看下如何画树。

二、画树

利用XGBoost Plotting API可以实现这个过程,来看下plot_tree这个函数。

Xgboost如何画出树?

num_trees代表画的第几颗树。rankdir='LR'代码树是从左到右画。

1)安装导入相关包: XGBoost Plotting API需要用到graphviz 和pydot,我是Win10 环境+Anaconda3,pydot直接 pip install pydot 即可。对于graphviz,也简单pip install graphviz,虽然不会报错,但是在调用plot_tree时会报错。解决方案如下:graphviz需要先下载一个windows版本的graphviz安装包,下载地址如下。然后在命令行中输入以下代码将下载的graphviz添加到系统环境变量中即可。

import os
os.environ["PATH"] += os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin/'

2)利用plot_tree画图

plot_tree(clf, num_trees=0)
plt.show()

结果如下:

Xgboost如何画出树?

 我们可以看到,图是十分地不清晰。因此,接下来,我们主要解决2个问题:

1)如何清晰画图?

2)如何把图中节点中的特征换成数据集的特征名,而不是0,1,2这种。

对于第一个问题,有大神有很简单的方法:

plot_tree(clf, num_trees=0, fmap='xgb.fmap')
fig = plt.gcf()
fig.set_size_inches(150, 100)
#plt.show()
fig.savefig('tree.png')

这样画出的图很清晰:

Xgboost如何画出树?

同时,我们看到了在图中节点中也已经使用了特征名,这个是如何实现的呢?我们只需要加入

def ceate_feature_map(features):
    outfile = open('xgb.fmap', 'w')
    i = 0
    for feat in features:
        outfile.write('{0}\t{1}\tq\n'.format(i, feat))
        i = i + 1
    outfile.close()
'''
X_train.columns在第一段代码中也已经设置过了。
特别需要注意:列名字中不能有空格。
'''
ceate_feature_map(X_train.columns)

这个函数就是根据给定的特征名字(我直接使用了数据的列名称), 按照特定格式生成一个xgb.fmap文件, 这个文件就是XGBoost文档里面多次提到的fmap, 注意使用的时候, 直接提供文件名, 比如fmap='xgb.fmap'.在画图的时候利用plot_tree(clf, num_trees=0, fmap='xgb.fmap')即可以找到对应关系。


三、参考文献

【1】如何画XGBoost里面的决策树(decision tree)