Get the decision tree rule/path pattern for every row of a predicted dataset with the rpart / ctree packages in R

Time: 2021-09-21 19:37:11

I have built decision tree models in R using rpart and ctree. I have also used the fitted model to predict on a new dataset and obtained predicted probabilities and classes.

However, I would like to extract, as a single string, the rule/path that each observation in the predicted dataset has followed. With this stored in tabular format, I could explain each prediction and the reason for it in an automated manner without opening R.

That is, I would like to get the following:

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

The kind of code I'm looking for is:

library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)

This is the call that probably needs to be expanded into multiple lines:

predResults <- predict(fit, newdata = newcar, type= "GETPATTERNS")

1 answer

#1


The partykit package has a function .list.rules.party() which is currently unexported but can be leveraged to do what you want to do. The main reason that we haven't exported it, yet, is that its type of output may change in future versions.
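Because the function is unexported, a script that depends on it may want to check that it is still present before using it. The following is only a defensive sketch, assuming the function keeps its current name; the variable names `ns` and `list_rules` are introduced here for illustration.

```r
library("partykit")

## look the function up inside the partykit namespace
## (unexported objects are not visible via partykit::)
ns <- asNamespace("partykit")
if (!exists(".list.rules.party", envir = ns, inherits = FALSE)) {
  stop("partykit no longer provides .list.rules.party(); check the package NEWS")
}

## keep a local handle so later code does not need the ::: operator
list_rules <- get(".list.rules.party", envir = ns, inherits = FALSE)
```

If the lookup fails after a package update, the failure then shows up as one clear error message instead of somewhere deep inside a prediction pipeline.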

To obtain the predictions you describe above you can do:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}
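To see why the lookup in the last step works, it can help to inspect the two pieces separately. The sketch below assumes the current behaviour of the unexported function, namely that it returns a character vector whose names are terminal node ids; that format is exactly what may change in future versions.

```r
library("rpart")
library("partykit")

## fit a small tree and convert it to a "party" object
rp <- as.party(rpart(Species ~ ., data = iris))

## one rule string per terminal node, named by node id
rls <- partykit:::.list.rules.party(rp)
print(rls)

## terminal node id that each observation of the learning data falls into
nodes <- predict(rp, type = "node")
print(table(nodes))

## indexing by name (hence as.character) rather than by position
## selects the rule of the node each observation ends up in
print(head(rls[as.character(nodes)]))
```

Indexing with the integer node ids directly would select by position and return the wrong rules (or NA), which is why the ids are converted to character first.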

Illustration using the iris data and rpart():

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(Only the first observation of each species is shown for brevity here. This corresponds to indexes 1, 51, and 101.)

And with ctree():

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
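The question asks for a table with one row per new observation. Since pathpred() forwards its extra arguments to predict(), passing newdata should be enough. The following is a sketch under assumptions: the hold-out split, the column names taken from the question, and the CSV path are illustrative choices, and the helper is repeated only to keep the snippet self-contained.

```r
library("rpart")
library("partykit")

## same helper as defined above
pathpred <- function(object, ...) {
  if (!inherits(object, "party")) object <- as.party(object)
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)
  rls <- partykit:::.list.rules.party(object)
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]
  rval
}

## hypothetical split: hold out 10 rows of iris to play the role of new data
set.seed(1)
test_idx <- sample(nrow(iris), 10)
fit <- rpart(Species ~ ., data = iris[-test_idx, ])

## newdata is passed through ... to every predict() call inside pathpred()
pp <- pathpred(fit, newdata = iris[test_idx, ])

## reshape into the layout sketched in the question
out <- data.frame(
  ObsID          = test_idx,
  Probability    = unname(apply(pp$prob, 1, max)),  # probability of the predicted class
  PredictedClass = as.character(pp$response),
  PathFollowed   = unname(pp$rule),
  row.names      = NULL
)

## store as a plain table that can be read without R
csv_file <- file.path(tempdir(), "path_predictions.csv")
write.csv(out, csv_file, row.names = FALSE)
print(out)
```

Taking the row-wise maximum of the probability matrix gives the probability of the predicted class, because the response is the class with the highest predicted probability.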
