【Feature Selection】3 - Nested Resampling

Date: 2024-12-10 22:47:26

We use the Titanic data set to build a classification task: predicting whether a passenger survived.

Data cleaning

    library(mlr3verse)
    library(mlr3fselect)
    set.seed(7832)
    # Reduce logging noise from mlr3 and bbotk
    lgr::get_logger("mlr3")$set_threshold("warn")
    lgr::get_logger("bbotk")$set_threshold("warn")
    library(mlr3data)
    data("titanic", package = "mlr3data")
    # Impute missing ages with the median and missing ports of embarkation with "S"
    titanic$age[is.na(titanic$age)] = median(titanic$age, na.rm = TRUE)
    titanic$embarked[is.na(titanic$embarked)] = "S"
    # Drop the free-text columns
    titanic$ticket = NULL
    titanic$name = NULL
    titanic$cabin = NULL
    # Keep only rows with a known outcome
    titanic = titanic[!is.na(titanic$survived), ]

Create the machine learning task:

task = as_task_classif(titanic, target = "survived", positive = "yes")

Choosing the model

    library(mlr3learners)
    # Logistic regression learner.
    # To evaluate the predictive performance, we choose a 3-fold cross-validation
    # and the classification error as the measure.
    learner = lrn("classif.log_reg")
    resampling = rsmp("cv", folds = 3)
    measure = msr("classif.ce")
    # Fix the train/test splits so every feature subset is scored on the same folds
    resampling$instantiate(task)
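To see what a 3-fold cross-validated classification error amounts to, the following base-R sketch computes one by hand. It does not use mlr3. It assumes the `Titanic` contingency table that ships with base R (a different data set from `mlr3data::titanic`, with only three predictors) and uses `glm` as the logistic regression; the object names `people`, `folds`, and `fold_err` are illustrative.

```r
# Illustrative sketch only: manual 3-fold cross-validation of a logistic regression.
set.seed(7832)

# Expand the base-R Titanic contingency table to one row per person
counts = as.data.frame(Titanic)
people = counts[rep(seq_len(nrow(counts)), counts$Freq),
                c("Class", "Sex", "Age", "Survived")]

# Randomly assign every row to one of 3 folds
folds = sample(rep(1:3, length.out = nrow(people)))

fold_err = sapply(1:3, function(k) {
  train = people[folds != k, ]
  test = people[folds == k, ]
  fit = glm(Survived ~ Class + Sex + Age, data = train, family = binomial)
  # Predicted probability of the second factor level ("Yes")
  prob = predict(fit, newdata = test, type = "response")
  pred = ifelse(prob > 0.5, "Yes", "No")
  # Classification error on the held-out fold
  mean(pred != test$Survived)
})

fold_err        # one error per fold
mean(fold_err)  # aggregated cross-validated error
```

The mlr3 objects `resampling` and `measure` defined in the post play the roles of `folds` and the `mean(pred != ...)` expression in this sketch.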

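The objects above can be treated as global settings. The feature selection methods used afterwards differ from one another, but task, learner, resampling, and measure stay the same. The terminator decides when the search stops, and it depends on the algorithm.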

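In short, FSelectInstanceSingleCrit$new defines the feature selection problem: it bundles the machine learning task, the learner, the resampling strategy, the performance measure, and the termination criterion. To list the available feature selection methods: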

mlr_fselectors

The following example uses Sequential Forward Selection to walk through the process.

    library(mlr3verse)
    library(mlr3fselect)
    set.seed(7832)
    lgr::get_logger("mlr3")$set_threshold("warn")
    lgr::get_logger("bbotk")$set_threshold("warn")
    library(mlr3data)
    data("titanic", package = "mlr3data")
    titanic$age[is.na(titanic$age)] = median(titanic$age, na.rm = TRUE)
    titanic$embarked[is.na(titanic$embarked)] = "S"
    titanic$ticket = NULL
    titanic$name = NULL
    titanic$cabin = NULL
    titanic = titanic[!is.na(titanic$survived), ]
    task = as_task_classif(titanic, target = "survived", positive = "yes")
    library(mlr3learners)
    learner = lrn("classif.log_reg")
    resampling = rsmp("cv", folds = 3)
    measure = msr("classif.ce")
    resampling$instantiate(task)
    # Stop once the best score has stopped improving over the last 5 evaluations
    terminator = trm("stagnation", iters = 5)
    # Bundle everything that defines the feature selection problem
    instance = FSelectInstanceSingleCrit$new(
      task = task,
      learner = learner,
      resampling = resampling,
      measure = measure,
      terminator = terminator)
    # fs("sequential") runs sequential forward selection by default
    fselector = fs("sequential")
    fselector$optimize(instance)
    # Show which feature was added at each step
    fselector$optimization_path(instance)
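The search strategy itself can be sketched in a few lines of base R. This is not the mlr3fselect implementation, only a minimal illustration of greedy forward selection. It assumes the built-in `iris` data, a single holdout split instead of cross-validation, and a hypothetical toy classifier (`centroid_error`, a nearest-centroid rule); it stops as soon as no remaining feature lowers the holdout error.

```r
# Illustrative sketch only: greedy forward selection with a toy classifier.
set.seed(1)
data(iris)
features = setdiff(names(iris), "Species")
train_idx = sample(nrow(iris), 100)
train = iris[train_idx, ]
test = iris[-train_idx, ]

# Hypothetical helper: holdout classification error of a nearest-centroid
# classifier that only sees the features in `feats`.
centroid_error = function(feats, train, test) {
  classes = levels(train$Species)
  # One row of feature means per class
  centroids = sapply(feats, function(f) tapply(train[[f]], train$Species, mean))
  centroids = matrix(centroids, nrow = length(classes),
                     dimnames = list(classes, feats))
  x = as.matrix(test[, feats, drop = FALSE])
  # Squared Euclidean distance from every test row to every class centroid
  dists = sapply(classes, function(cl) rowSums(sweep(x, 2, centroids[cl, ])^2))
  pred = classes[max.col(-dists, ties.method = "first")]
  mean(pred != test$Species)
}

selected = character(0)
best_err = Inf
repeat {
  candidates = setdiff(features, selected)
  if (length(candidates) == 0) break
  # Score every subset obtained by adding one more feature
  errs = sapply(candidates, function(f) centroid_error(c(selected, f), train, test))
  if (min(errs) >= best_err) break  # no candidate improves the score
  best_err = min(errs)
  selected = c(selected, names(which.min(errs)))
}

selected  # features in the order they were added
best_err  # holdout error of the selected subset
```

Each pass through the loop corresponds to one step of the optimization path: all one-feature extensions of the current subset are scored and the best one is kept.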

Nested resampling for feature selection

The graphic above illustrates nested resampling for parameter tuning with 3-fold cross-validation in the outer and 4-fold cross-validation in the inner loop.

The repeated evaluation of the model might leak information about the test sets into the model and thus lead to over-fitting and over-optimistic performance results. Nested resampling uses an outer and an inner resampling to separate the feature selection from the performance estimation of the model.
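The separation of the two loops can be written out by hand. The sketch below is a base-R illustration under stated assumptions, not the mlr3fselect code path: it uses the built-in `iris` data, a hypothetical nearest-centroid classifier (`centroid_error`), an inner holdout split with an exhaustive search over feature subsets (`select_features`, also hypothetical), and a 3-fold outer cross-validation.

```r
# Illustrative sketch only: nested resampling around a feature selection step.
set.seed(1)
data(iris)
features = setdiff(names(iris), "Species")

# Hypothetical helper: error of a nearest-centroid classifier on `test`,
# trained on `train` using only the features in `feats`.
centroid_error = function(feats, train, test) {
  classes = levels(train$Species)
  centroids = sapply(feats, function(f) tapply(train[[f]], train$Species, mean))
  centroids = matrix(centroids, nrow = length(classes),
                     dimnames = list(classes, feats))
  x = as.matrix(test[, feats, drop = FALSE])
  dists = sapply(classes, function(cl) rowSums(sweep(x, 2, centroids[cl, ])^2))
  pred = classes[max.col(-dists, ties.method = "first")]
  mean(pred != test$Species)
}

# Hypothetical helper: inner loop. Picks the feature subset with the lowest
# error on an inner holdout split of the data it is given.
select_features = function(data) {
  idx = sample(nrow(data), floor(2 / 3 * nrow(data)))
  inner_train = data[idx, ]
  inner_valid = data[-idx, ]
  # All non-empty feature subsets (feasible here because there are only 4 features)
  subsets = unlist(lapply(seq_along(features), function(k) {
    combn(features, k, simplify = FALSE)
  }), recursive = FALSE)
  errs = sapply(subsets, centroid_error, train = inner_train, test = inner_valid)
  subsets[[which.min(errs)]]
}

# Outer loop: 3-fold cross-validation
folds = sample(rep(1:3, length.out = nrow(iris)))
outer_err = sapply(1:3, function(k) {
  outer_train = iris[folds != k, ]
  outer_test = iris[folds == k, ]
  # Feature selection only ever sees the outer training data
  feats = select_features(outer_train)
  # The chosen subset is then scored on data that played no part in choosing it
  centroid_error(feats, outer_train, outer_test)
})

outer_err        # one error per outer fold
mean(outer_err)  # nested-resampling performance estimate
```

Because `select_features` never touches `outer_test`, the outer errors are not influenced by the repeated evaluations made during the search.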

The above lays out the steps in principle; in practice the code can be shortened:

    # Nested resampling on Palmer Penguins data set
    rr = fselect_nested(
      fselector = fs("random_search"),
      task = tsk("penguins"),
      # The learner id is blank in the source; "classif.rpart" is an assumption
      learner = lrn("classif.rpart"),
      inner_resampling = rsmp("holdout"),
      outer_resampling = rsmp("cv", folds = 2),
      # The measure id is blank in the source; classification error is assumed,
      # matching the measure used earlier in this post
      measure = msr("classif.ce"),
      term_evals = 4)
    # Performance scores estimated on the outer resampling
    rr$score()
    # Unbiased performance of the final model trained on the full data set
    rr$aggregate()