Spark2 Random Forests 随机森林

时间:2022-05-19 15:14:51

  随机森林是决策树的集合。随机森林通过组合许多决策树来降低过度拟合的风险。spark.ml 的随机森林实现同时支持连续特征和分类特征,可用于二分类、多分类以及回归。

导入包

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.{ IndexToString, StringIndexer, VectorIndexer }
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{ RandomForestClassificationModel, RandomForestClassifier }
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }

导入源数据

// Columns of the source data set:
// affairs:       frequency of extramarital affairs during the past year
// gender:        gender
// age:           age
// yearsmarried:  number of years married
// children:      whether the respondent has children
// religiousness: degree of religiousness (5-point scale, 1 = anti-religious, 5 = very religious)
// education:     level of education
// occupation:    occupation (reverse-numbered 7-category classification, attributed to "Gordon" in the original note)
// rating:        self-rating of the marriage (5-point scale, 1 = very unhappy, 5 = very happy)

// Create (or reuse) the SparkSession. This statement must sit on its own line:
// when it trails a `//` comment it is silently commented out.
val spark = SparkSession.builder().appName("Spark Random Forest Classifier").config("spark.some.config.option", "some-value").getOrCreate()

// For implicit conversions like converting RDDs to DataFrames
import spark.implicits._

// In-memory copy of the source data. Each tuple is laid out as
// (affairs, gender, age, yearsmarried, children, religiousness, education, occupation, rating),
// matching the column names passed to `toDF` below.
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0, "male", 37, 10, "no", 3, 18, 7, 4),
  (0, "female", 27, 4, "no", 4, 14, 6, 4),
  (0, "female", 32, 15, "yes", 1, 12, 1, 4),
  (0, "male", 57, 15, "yes", 5, 18, 6, 5),
  (0, "male", 22, 0.75, "no", 2, 17, 6, 3),
  (0, "female", 32, 1.5, "no", 2, 17, 5, 5),
  (0, "female", 22, 0.75, "no", 2, 12, 1, 3),
  (0, "male", 57, 15, "yes", 2, 14, 4, 4),
  (0, "female", 32, 15, "yes", 4, 16, 1, 2),
  (0, "male", 22, 1.5, "no", 4, 14, 4, 5),
  (0, "male", 37, 15, "yes", 2, 20, 7, 2),
  (0, "male", 27, 4, "yes", 4, 18, 6, 4),
  (0, "male", 47, 15, "yes", 5, 17, 6, 4),
  (0, "female", 22, 1.5, "no", 2, 17, 5, 4),
  (0, "female", 27, 4, "no", 4, 14, 5, 4),
  (0, "female", 37, 15, "yes", 1, 17, 5, 5),
  (0, "female", 37, 15, "yes", 2, 18, 4, 3),
  (0, "female", 22, 0.75, "no", 3, 16, 5, 4),
  (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "female", 27, 10, "yes", 2, 14, 1, 5),
  (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "female", 27, 10, "yes", 4, 16, 5, 4),
  (0, "female", 32, 10, "yes", 3, 14, 1, 5),
  (0, "male", 37, 4, "yes", 2, 20, 6, 4),
  (0, "female", 22, 1.5, "no", 2, 18, 5, 5),
  (0, "female", 27, 7, "no", 4, 16, 1, 5),
  (0, "male", 42, 15, "yes", 5, 20, 6, 4),
  (0, "male", 27, 4, "yes", 3, 16, 5, 5),
  (0, "female", 27, 4, "yes", 3, 17, 5, 4),
  (0, "male", 42, 15, "yes", 4, 20, 6, 3),
  (0, "female", 22, 1.5, "no", 3, 16, 5, 5),
  (0, "male", 27, 0.417, "no", 4, 17, 6, 4),
  (0, "female", 42, 15, "yes", 5, 14, 5, 4),
  (0, "male", 32, 4, "yes", 1, 18, 6, 4),
  (0, "female", 22, 1.5, "no", 4, 16, 5, 3),
  (0, "female", 42, 15, "yes", 3, 12, 1, 4),
  (0, "female", 22, 4, "no", 4, 17, 5, 5),
  (0, "male", 22, 1.5, "yes", 1, 14, 3, 5),
  (0, "female", 22, 0.75, "no", 3, 16, 1, 5),
  (0, "male", 32, 10, "yes", 5, 20, 6, 5),
  (0, "male", 52, 15, "yes", 5, 18, 6, 3),
  (0, "female", 22, 0.417, "no", 5, 14, 1, 4),
  (0, "female", 27, 4, "yes", 2, 18, 6, 1),
  (0, "female", 32, 7, "yes", 5, 17, 5, 3),
  (0, "male", 22, 4, "no", 3, 16, 5, 5),
  (0, "female", 27, 7, "yes", 4, 18, 6, 5),
  (0, "female", 42, 15, "yes", 2, 18, 5, 4),
  (0, "male", 27, 1.5, "yes", 4, 16, 3, 5),
  (0, "male", 42, 15, "yes", 2, 20, 6, 4),
  (0, "female", 22, 0.75, "no", 5, 14, 3, 5),
  (0, "male", 32, 7, "yes", 2, 20, 6, 4),
  (0, "male", 27, 4, "yes", 5, 20, 6, 5),
  (0, "male", 27, 10, "yes", 4, 20, 6, 4),
  (0, "male", 22, 4, "no", 1, 18, 5, 5),
  (0, "female", 37, 15, "yes", 4, 14, 3, 1),
  (0, "male", 22, 1.5, "yes", 5, 16, 4, 4),
  (0, "female", 37, 15, "yes", 4, 17, 1, 5),
  (0, "female", 27, 0.75, "no", 4, 17, 5, 4),
  (0, "male", 32, 10, "yes", 4, 20, 6, 4),
  (0, "female", 47, 15, "yes", 5, 14, 7, 2),
  (0, "male", 37, 10, "yes", 3, 20, 6, 4),
  (0, "female", 22, 0.75, "no", 2, 16, 5, 5),
  (0, "male", 27, 4, "no", 2, 18, 4, 5),
  (0, "male", 32, 7, "no", 4, 20, 6, 4),
  (0, "male", 42, 15, "yes", 2, 17, 3, 5),
  (0, "male", 37, 10, "yes", 4, 20, 6, 4),
  (0, "female", 47, 15, "yes", 3, 17, 6, 5),
  (0, "female", 22, 1.5, "no", 5, 16, 5, 5),
  (0, "female", 27, 1.5, "no", 2, 16, 6, 4),
  (0, "female", 27, 4, "no", 3, 17, 5, 5),
  (0, "female", 32, 10, "yes", 5, 14, 4, 5),
  (0, "female", 22, 0.125, "no", 2, 12, 5, 5),
  (0, "male", 47, 15, "yes", 4, 14, 4, 3),
  (0, "male", 32, 15, "yes", 1, 14, 5, 5),
  (0, "male", 27, 7, "yes", 4, 16, 5, 5),
  (0, "female", 22, 1.5, "yes", 3, 16, 5, 5),
  (0, "male", 27, 4, "yes", 3, 17, 6, 5),
  (0, "female", 22, 1.5, "no", 3, 16, 5, 5),
  (0, "male", 57, 15, "yes", 2, 14, 7, 2),
  (0, "male", 17.5, 1.5, "yes", 3, 18, 6, 5),
  (0, "male", 57, 15, "yes", 4, 20, 6, 5),
  (0, "female", 22, 0.75, "no", 2, 16, 3, 4),
  (0, "male", 42, 4, "no", 4, 17, 3, 3),
  (0, "female", 22, 1.5, "yes", 4, 12, 1, 5),
  (0, "female", 22, 0.417, "no", 1, 17, 6, 4),
  (0, "female", 32, 15, "yes", 4, 17, 5, 5),
  (0, "female", 27, 1.5, "no", 3, 18, 5, 2),
  (0, "female", 22, 1.5, "yes", 3, 14, 1, 5),
  (0, "female", 37, 15, "yes", 3, 14, 1, 4),
  (0, "female", 32, 15, "yes", 4, 14, 3, 4),
  (0, "male", 37, 10, "yes", 2, 14, 5, 3),
  (0, "male", 37, 10, "yes", 4, 16, 5, 4),
  (0, "male", 57, 15, "yes", 5, 20, 5, 3),
  (0, "male", 27, 0.417, "no", 1, 16, 3, 4),
  (0, "female", 42, 15, "yes", 5, 14, 1, 5),
  (0, "male", 57, 15, "yes", 3, 16, 6, 1),
  (0, "male", 37, 10, "yes", 1, 16, 6, 4),
  (0, "male", 37, 15, "yes", 3, 17, 5, 5),
  (0, "male", 37, 15, "yes", 4, 20, 6, 5),
  (0, "female", 27, 10, "yes", 5, 14, 1, 5),
  (0, "male", 37, 10, "yes", 2, 18, 6, 4),
  (0, "female", 22, 0.125, "no", 4, 12, 4, 5),
  (0, "male", 57, 15, "yes", 5, 20, 6, 5),
  (0, "female", 37, 15, "yes", 4, 18, 6, 4),
  (0, "male", 22, 4, "yes", 4, 14, 6, 4),
  (0, "male", 27, 7, "yes", 4, 18, 5, 4),
  (0, "male", 57, 15, "yes", 4, 20, 5, 4),
  (0, "male", 32, 15, "yes", 3, 14, 6, 3),
  (0, "female", 22, 1.5, "no", 2, 14, 5, 4),
  (0, "female", 32, 7, "yes", 4, 17, 1, 5),
  (0, "female", 37, 15, "yes", 4, 17, 6, 5),
  (0, "female", 32, 1.5, "no", 5, 18, 5, 5),
  (0, "male", 42, 10, "yes", 5, 20, 7, 4),
  (0, "female", 27, 7, "no", 3, 16, 5, 4),
  (0, "male", 37, 15, "no", 4, 20, 6, 5),
  (0, "male", 37, 15, "yes", 4, 14, 3, 2),
  (0, "male", 32, 10, "no", 5, 18, 6, 4),
  (0, "female", 22, 0.75, "no", 4, 16, 1, 5),
  (0, "female", 27, 7, "yes", 4, 12, 2, 4),
  (0, "female", 27, 7, "yes", 2, 16, 2, 5),
  (0, "female", 42, 15, "yes", 5, 18, 5, 4),
  (0, "male", 42, 15, "yes", 4, 17, 5, 3),
  (0, "female", 27, 7, "yes", 2, 16, 1, 2),
  (0, "female", 22, 1.5, "no", 3, 16, 5, 5),
  (0, "male", 37, 15, "yes", 5, 20, 6, 5),
  (0, "female", 22, 0.125, "no", 2, 14, 4, 5),
  (0, "male", 27, 1.5, "no", 4, 16, 5, 5),
  (0, "male", 32, 1.5, "no", 2, 18, 6, 5),
  (0, "male", 27, 1.5, "no", 2, 17, 6, 5),
  (0, "female", 27, 10, "yes", 4, 16, 1, 3),
  (0, "male", 42, 15, "yes", 4, 18, 6, 5),
  (0, "female", 27, 1.5, "no", 2, 16, 6, 5),
  (0, "male", 27, 4, "no", 2, 18, 6, 3),
  (0, "female", 32, 10, "yes", 3, 14, 5, 3),
  (0, "female", 32, 15, "yes", 3, 18, 5, 4),
  (0, "female", 22, 0.75, "no", 2, 18, 6, 5),
  (0, "female", 37, 15, "yes", 2, 16, 1, 4),
  (0, "male", 27, 4, "yes", 4, 20, 5, 5),
  (0, "male", 27, 4, "no", 1, 20, 5, 4),
  (0, "female", 27, 10, "yes", 2, 12, 1, 4),
  (0, "female", 32, 15, "yes", 5, 18, 6, 4),
  (0, "male", 27, 7, "yes", 5, 12, 5, 3),
  (0, "male", 52, 15, "yes", 2, 18, 5, 4),
  (0, "male", 27, 4, "no", 3, 20, 6, 3),
  (0, "male", 37, 4, "yes", 1, 18, 5, 4),
  (0, "male", 27, 4, "yes", 4, 14, 5, 4),
  (0, "female", 52, 15, "yes", 5, 12, 1, 3),
  (0, "female", 57, 15, "yes", 4, 16, 6, 4),
  (0, "male", 27, 7, "yes", 1, 16, 5, 4),
  (0, "male", 37, 7, "yes", 4, 20, 6, 3),
  (0, "male", 22, 0.75, "no", 2, 14, 4, 3),
  (0, "male", 32, 4, "yes", 2, 18, 5, 3),
  (0, "male", 37, 15, "yes", 4, 20, 6, 3),
  (0, "male", 22, 0.75, "yes", 2, 14, 4, 3),
  (0, "male", 42, 15, "yes", 4, 20, 6, 3),
  (0, "female", 52, 15, "yes", 5, 17, 1, 1),
  (0, "female", 37, 15, "yes", 4, 14, 1, 2),
  (0, "male", 27, 7, "yes", 4, 14, 5, 3),
  (0, "male", 32, 4, "yes", 2, 16, 5, 5),
  (0, "female", 27, 4, "yes", 2, 18, 6, 5),
  (0, "female", 27, 4, "yes", 2, 18, 5, 5),
  (0, "male", 37, 15, "yes", 5, 18, 6, 5),
  (0, "female", 47, 15, "yes", 5, 12, 5, 4),
  (0, "female", 32, 10, "yes", 3, 17, 1, 4),
  (0, "female", 27, 1.5, "yes", 4, 17, 1, 2),
  (0, "female", 57, 15, "yes", 2, 18, 5, 2),
  (0, "female", 22, 1.5, "no", 4, 14, 5, 4),
  (0, "male", 42, 15, "yes", 3, 14, 3, 4),
  (0, "male", 57, 15, "yes", 4, 9, 2, 2),
  (0, "male", 57, 15, "yes", 4, 20, 6, 5),
  (0, "female", 22, 0.125, "no", 4, 14, 4, 5),
  (0, "female", 32, 10, "yes", 4, 14, 1, 5),
  (0, "female", 42, 15, "yes", 3, 18, 5, 4),
  (0, "female", 27, 1.5, "no", 2, 18, 6, 5),
  (0, "male", 32, 0.125, "yes", 2, 18, 5, 2),
  (0, "female", 27, 4, "no", 3, 16, 5, 4),
  (0, "female", 27, 10, "yes", 2, 16, 1, 4),
  (0, "female", 32, 7, "yes", 4, 16, 1, 3),
  (0, "female", 37, 15, "yes", 4, 14, 5, 4),
  (0, "female", 42, 15, "yes", 5, 17, 6, 2),
  (0, "male", 32, 1.5, "yes", 4, 14, 6, 5),
  (0, "female", 32, 4, "yes", 3, 17, 5, 3),
  (0, "female", 37, 7, "no", 4, 18, 5, 5),
  (0, "female", 22, 0.417, "yes", 3, 14, 3, 5),
  (0, "female", 27, 7, "yes", 4, 14, 1, 5),
  (0, "male", 27, 0.75, "no", 3, 16, 5, 5),
  (0, "male", 27, 4, "yes", 2, 20, 5, 5),
  (0, "male", 32, 10, "yes", 4, 16, 4, 5),
  (0, "male", 32, 15, "yes", 1, 14, 5, 5),
  (0, "male", 22, 0.75, "no", 3, 17, 4, 5),
  (0, "female", 27, 7, "yes", 4, 17, 1, 4),
  (0, "male", 27, 0.417, "yes", 4, 20, 5, 4),
  (0, "male", 37, 15, "yes", 4, 20, 5, 4),
  (0, "female", 37, 15, "yes", 2, 14, 1, 3),
  (0, "male", 22, 4, "yes", 1, 18, 5, 4),
  (0, "male", 37, 15, "yes", 4, 17, 5, 3),
  (0, "female", 22, 1.5, "no", 2, 14, 4, 5),
  (0, "male", 52, 15, "yes", 4, 14, 6, 2),
  (0, "female", 22, 1.5, "no", 4, 17, 5, 5),
  (0, "male", 32, 4, "yes", 5, 14, 3, 5),
  (0, "male", 32, 4, "yes", 2, 14, 3, 5),
  (0, "female", 22, 1.5, "no", 3, 16, 6, 5),
  (0, "male", 27, 0.75, "no", 2, 18, 3, 3),
  (0, "female", 22, 7, "yes", 2, 14, 5, 2),
  (0, "female", 27, 0.75, "no", 2, 17, 5, 3),
  (0, "female", 37, 15, "yes", 4, 12, 1, 2),
  (0, "female", 22, 1.5, "no", 1, 14, 1, 5),
  (0, "female", 37, 10, "no", 2, 12, 4, 4),
  (0, "female", 37, 15, "yes", 4, 18, 5, 3),
  (0, "female", 42, 15, "yes", 3, 12, 3, 3),
  (0, "male", 22, 4, "no", 2, 18, 5, 5),
  (0, "male", 52, 7, "yes", 2, 20, 6, 2),
  (0, "male", 27, 0.75, "no", 2, 17, 5, 5),
  (0, "female", 27, 4, "no", 2, 17, 4, 5),
  (0, "male", 42, 1.5, "no", 5, 20, 6, 5),
  (0, "male", 22, 1.5, "no", 4, 17, 6, 5),
  (0, "male", 22, 4, "no", 4, 17, 5, 3),
  (0, "female", 22, 4, "yes", 1, 14, 5, 4),
  (0, "male", 37, 15, "yes", 5, 20, 4, 5),
  (0, "female", 37, 10, "yes", 3, 16, 6, 3),
  (0, "male", 42, 15, "yes", 4, 17, 6, 5),
  (0, "female", 47, 15, "yes", 4, 17, 5, 5),
  (0, "male", 22, 1.5, "no", 4, 16, 5, 4),
  (0, "female", 32, 10, "yes", 3, 12, 1, 4),
  (0, "female", 22, 7, "yes", 1, 14, 3, 5),
  (0, "female", 32, 10, "yes", 4, 17, 5, 4),
  (0, "male", 27, 1.5, "yes", 2, 16, 2, 4),
  (0, "male", 37, 15, "yes", 4, 14, 5, 5),
  (0, "male", 42, 4, "yes", 3, 14, 4, 5),
  (0, "female", 37, 15, "yes", 5, 14, 5, 4),
  (0, "female", 32, 7, "yes", 4, 17, 5, 5),
  (0, "female", 42, 15, "yes", 4, 18, 6, 5),
  (0, "male", 27, 4, "no", 4, 18, 6, 4),
  (0, "male", 22, 0.75, "no", 4, 18, 6, 5),
  (0, "male", 27, 4, "yes", 4, 14, 5, 3),
  (0, "female", 22, 0.75, "no", 5, 18, 1, 5),
  (0, "female", 52, 15, "yes", 5, 9, 5, 5),
  (0, "male", 32, 10, "yes", 3, 14, 5, 5),
  (0, "female", 37, 15, "yes", 4, 16, 4, 4),
  (0, "male", 32, 7, "yes", 2, 20, 5, 4),
  (0, "female", 42, 15, "yes", 3, 18, 1, 4),
  (0, "male", 32, 15, "yes", 1, 16, 5, 5),
  (0, "male", 27, 4, "yes", 3, 18, 5, 5),
  (0, "female", 32, 15, "yes", 4, 12, 3, 4),
  (0, "male", 22, 0.75, "yes", 3, 14, 2, 4),
  (0, "female", 22, 1.5, "no", 3, 16, 5, 3),
  (0, "female", 42, 15, "yes", 4, 14, 3, 5),
  (0, "female", 52, 15, "yes", 3, 16, 5, 4),
  (0, "male", 37, 15, "yes", 5, 20, 6, 4),
  (0, "female", 47, 15, "yes", 4, 12, 2, 3),
  (0, "male", 57, 15, "yes", 2, 20, 6, 4),
  (0, "male", 32, 7, "yes", 4, 17, 5, 5),
  (0, "female", 27, 7, "yes", 4, 17, 1, 4),
  (0, "male", 22, 1.5, "no", 1, 18, 6, 5),
  (0, "female", 22, 4, "yes", 3, 9, 1, 4),
  (0, "female", 22, 1.5, "no", 2, 14, 1, 5),
  (0, "male", 42, 15, "yes", 2, 20, 6, 4),
  (0, "male", 57, 15, "yes", 4, 9, 2, 4),
  (0, "female", 27, 7, "yes", 2, 18, 1, 5),
  (0, "female", 22, 4, "yes", 3, 14, 1, 5),
  (0, "male", 37, 15, "yes", 4, 14, 5, 3),
  (0, "male", 32, 7, "yes", 1, 18, 6, 4),
  (0, "female", 22, 1.5, "no", 2, 14, 5, 5),
  (0, "female", 22, 1.5, "yes", 3, 12, 1, 3),
  (0, "male", 52, 15, "yes", 2, 14, 5, 5),
  (0, "female", 37, 15, "yes", 2, 14, 1, 1),
  (0, "female", 32, 10, "yes", 2, 14, 5, 5),
  (0, "male", 42, 15, "yes", 4, 20, 4, 5),
  (0, "female", 27, 4, "yes", 3, 18, 4, 5),
  (0, "male", 37, 15, "yes", 4, 20, 6, 5),
  (0, "male", 27, 1.5, "no", 3, 18, 5, 5),
  (0, "female", 22, 0.125, "no", 2, 16, 6, 3),
  (0, "male", 32, 10, "yes", 2, 20, 6, 3),
  (0, "female", 27, 4, "no", 4, 18, 5, 4),
  (0, "female", 27, 7, "yes", 2, 12, 5, 1),
  (0, "male", 32, 4, "yes", 5, 18, 6, 3),
  (0, "female", 37, 15, "yes", 2, 17, 5, 5),
  (0, "male", 47, 15, "no", 4, 20, 6, 4),
  (0, "male", 27, 1.5, "no", 1, 18, 5, 5),
  (0, "male", 37, 15, "yes", 4, 20, 6, 4),
  (0, "female", 32, 15, "yes", 4, 18, 1, 4),
  (0, "female", 32, 7, "yes", 4, 17, 5, 4),
  (0, "female", 42, 15, "yes", 3, 14, 1, 3),
  (0, "female", 27, 7, "yes", 3, 16, 1, 4),
  (0, "male", 27, 1.5, "no", 3, 16, 4, 2),
  (0, "male", 22, 1.5, "no", 3, 16, 3, 5),
  (0, "male", 27, 4, "yes", 3, 16, 4, 2),
  (0, "female", 27, 7, "yes", 3, 12, 1, 2),
  (0, "female", 37, 15, "yes", 2, 18, 5, 4),
  (0, "female", 37, 7, "yes", 3, 14, 4, 4),
  (0, "male", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "male", 37, 15, "yes", 5, 20, 5, 4),
  (0, "female", 22, 1.5, "no", 4, 16, 5, 3),
  (0, "female", 32, 10, "yes", 4, 16, 1, 5),
  (0, "male", 27, 4, "no", 2, 17, 5, 3),
  (0, "female", 22, 0.417, "no", 4, 14, 5, 5),
  (0, "female", 27, 4, "no", 2, 18, 5, 5),
  (0, "male", 37, 15, "yes", 4, 18, 5, 3),
  (0, "male", 37, 10, "yes", 5, 20, 7, 4),
  (0, "female", 27, 7, "yes", 2, 14, 4, 2),
  (0, "male", 32, 4, "yes", 2, 16, 5, 5),
  (0, "male", 32, 4, "yes", 2, 16, 6, 4),
  (0, "male", 22, 1.5, "no", 3, 18, 4, 5),
  (0, "female", 22, 4, "yes", 4, 14, 3, 4),
  (0, "female", 17.5, 0.75, "no", 2, 18, 5, 4),
  (0, "male", 32, 10, "yes", 4, 20, 4, 5),
  (0, "female", 32, 0.75, "no", 5, 14, 3, 3),
  (0, "male", 37, 15, "yes", 4, 17, 5, 3),
  (0, "male", 32, 4, "no", 3, 14, 4, 5),
  (0, "female", 27, 1.5, "no", 2, 17, 3, 2),
  (0, "female", 22, 7, "yes", 4, 14, 1, 5),
  (0, "male", 47, 15, "yes", 5, 14, 6, 5),
  (0, "male", 27, 4, "yes", 1, 16, 4, 4),
  (0, "female", 37, 15, "yes", 5, 14, 1, 3),
  (0, "male", 42, 4, "yes", 4, 18, 5, 5),
  (0, "female", 32, 4, "yes", 2, 14, 1, 5),
  (0, "male", 52, 15, "yes", 2, 14, 7, 4),
  (0, "female", 22, 1.5, "no", 2, 16, 1, 4),
  (0, "male", 52, 15, "yes", 4, 12, 2, 4),
  (0, "female", 22, 0.417, "no", 3, 17, 1, 5),
  (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "male", 27, 4, "yes", 4, 20, 6, 4),
  (0, "female", 32, 15, "yes", 4, 14, 1, 5),
  (0, "female", 27, 1.5, "no", 2, 16, 3, 5),
  (0, "male", 32, 4, "no", 1, 20, 6, 5),
  (0, "male", 37, 15, "yes", 3, 20, 6, 4),
  (0, "female", 32, 10, "no", 2, 16, 6, 5),
  (0, "female", 32, 10, "yes", 5, 14, 5, 5),
  (0, "male", 37, 1.5, "yes", 4, 18, 5, 3),
  (0, "male", 32, 1.5, "no", 2, 18, 4, 4),
  (0, "female", 32, 10, "yes", 4, 14, 1, 4),
  (0, "female", 47, 15, "yes", 4, 18, 5, 4),
  (0, "female", 27, 10, "yes", 5, 12, 1, 5),
  (0, "male", 27, 4, "yes", 3, 16, 4, 5),
  (0, "female", 37, 15, "yes", 4, 12, 4, 2),
  (0, "female", 27, 0.75, "no", 4, 16, 5, 5),
  (0, "female", 37, 15, "yes", 4, 16, 1, 5),
  (0, "female", 32, 15, "yes", 3, 16, 1, 5),
  (0, "female", 27, 10, "yes", 2, 16, 1, 5),
  (0, "male", 27, 7, "no", 2, 20, 6, 5),
  (0, "female", 37, 15, "yes", 2, 14, 1, 3),
  (0, "male", 27, 1.5, "yes", 2, 17, 4, 4),
  (0, "female", 22, 0.75, "yes", 2, 14, 1, 5),
  (0, "male", 22, 4, "yes", 4, 14, 2, 4),
  (0, "male", 42, 0.125, "no", 4, 17, 6, 4),
  (0, "male", 27, 1.5, "yes", 4, 18, 6, 5),
  (0, "male", 27, 7, "yes", 3, 16, 6, 3),
  (0, "female", 52, 15, "yes", 4, 14, 1, 3),
  (0, "male", 27, 1.5, "no", 5, 20, 5, 2),
  (0, "female", 27, 1.5, "no", 2, 16, 5, 5),
  (0, "female", 27, 1.5, "no", 3, 17, 5, 5),
  (0, "male", 22, 0.125, "no", 5, 16, 4, 4),
  (0, "female", 27, 4, "yes", 4, 16, 1, 5),
  (0, "female", 27, 4, "yes", 4, 12, 1, 5),
  (0, "female", 47, 15, "yes", 2, 14, 5, 5),
  (0, "female", 32, 15, "yes", 3, 14, 5, 3),
  (0, "male", 42, 7, "yes", 2, 16, 5, 5),
  (0, "male", 22, 0.75, "no", 4, 16, 6, 4),
  (0, "male", 27, 0.125, "no", 3, 20, 6, 5),
  (0, "male", 32, 10, "yes", 3, 20, 6, 5),
  (0, "female", 22, 0.417, "no", 5, 14, 4, 5),
  (0, "female", 47, 15, "yes", 5, 14, 1, 4),
  (0, "female", 32, 10, "yes", 3, 14, 1, 5),
  (0, "male", 57, 15, "yes", 4, 17, 5, 5),
  (0, "male", 27, 4, "yes", 3, 20, 6, 5),
  (0, "female", 32, 7, "yes", 4, 17, 1, 5),
  (0, "female", 37, 10, "yes", 4, 16, 1, 5),
  (0, "female", 32, 10, "yes", 1, 18, 1, 4),
  (0, "female", 22, 4, "no", 3, 14, 1, 4),
  (0, "female", 27, 7, "yes", 4, 14, 3, 2),
  (0, "male", 57, 15, "yes", 5, 18, 5, 2),
  (0, "male", 32, 7, "yes", 2, 18, 5, 5),
  (0, "female", 27, 1.5, "no", 4, 17, 1, 3),
  (0, "male", 22, 1.5, "no", 4, 14, 5, 5),
  (0, "female", 22, 1.5, "yes", 4, 14, 5, 4),
  (0, "female", 32, 7, "yes", 3, 16, 1, 5),
  (0, "female", 47, 15, "yes", 3, 16, 5, 4),
  (0, "female", 22, 0.75, "no", 3, 16, 1, 5),
  (0, "female", 22, 1.5, "yes", 2, 14, 5, 5),
  (0, "female", 27, 4, "yes", 1, 16, 5, 5),
  (0, "male", 52, 15, "yes", 4, 16, 5, 5),
  (0, "male", 32, 10, "yes", 4, 20, 6, 5),
  (0, "male", 47, 15, "yes", 4, 16, 6, 4),
  (0, "female", 27, 7, "yes", 2, 14, 1, 2),
  (0, "female", 22, 1.5, "no", 4, 14, 4, 5),
  (0, "female", 32, 10, "yes", 2, 16, 5, 4),
  (0, "female", 22, 0.75, "no", 2, 16, 5, 4),
  (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
  (0, "female", 42, 15, "yes", 3, 18, 6, 4),
  (0, "female", 27, 7, "yes", 5, 14, 4, 5),
  (0, "male", 42, 15, "yes", 4, 16, 4, 4),
  (0, "female", 57, 15, "yes", 3, 18, 5, 2),
  (0, "male", 42, 15, "yes", 3, 18, 6, 2),
  (0, "female", 32, 7, "yes", 2, 14, 1, 2),
  (0, "male", 22, 4, "no", 5, 12, 4, 5),
  (0, "female", 22, 1.5, "no", 1, 16, 6, 5),
  (0, "female", 22, 0.75, "no", 1, 14, 4, 5),
  (0, "female", 32, 15, "yes", 4, 12, 1, 5),
  (0, "male", 22, 1.5, "no", 2, 18, 5, 3),
  (0, "male", 27, 4, "yes", 5, 17, 2, 5),
  (0, "female", 27, 4, "yes", 4, 12, 1, 5),
  (0, "male", 42, 15, "yes", 5, 18, 5, 4),
  (0, "male", 32, 1.5, "no", 2, 20, 7, 3),
  (0, "male", 57, 15, "no", 4, 9, 3, 1),
  (0, "male", 37, 7, "no", 4, 18, 5, 5),
  (0, "male", 52, 15, "yes", 2, 17, 5, 4),
  (0, "male", 47, 15, "yes", 4, 17, 6, 5),
  (0, "female", 27, 7, "no", 2, 17, 5, 4),
  (0, "female", 27, 7, "yes", 4, 14, 5, 5),
  (0, "female", 22, 4, "no", 2, 14, 3, 3),
  (0, "male", 37, 7, "yes", 2, 20, 6, 5),
  (0, "male", 27, 7, "no", 4, 12, 4, 3),
  (0, "male", 42, 10, "yes", 4, 18, 6, 4),
  (0, "female", 22, 1.5, "no", 3, 14, 1, 5),
  (0, "female", 22, 4, "yes", 2, 14, 1, 3),
  (0, "female", 57, 15, "no", 4, 20, 6, 5),
  (0, "male", 37, 15, "yes", 4, 14, 4, 3),
  (0, "female", 27, 7, "yes", 3, 18, 5, 5),
  (0, "female", 17.5, 10, "no", 4, 14, 4, 5),
  (0, "male", 22, 4, "yes", 4, 16, 5, 5),
  (0, "female", 27, 4, "yes", 2, 16, 1, 4),
  (0, "female", 37, 15, "yes", 2, 14, 5, 1),
  (0, "female", 22, 1.5, "no", 5, 14, 1, 4),
  (0, "male", 27, 7, "yes", 2, 20, 5, 4),
  (0, "male", 27, 4, "yes", 4, 14, 5, 5),
  (0, "male", 22, 0.125, "no", 1, 16, 3, 5),
  (0, "female", 27, 7, "yes", 4, 14, 1, 4),
  (0, "female", 32, 15, "yes", 5, 16, 5, 3),
  (0, "male", 32, 10, "yes", 4, 18, 5, 4),
  (0, "female", 32, 15, "yes", 2, 14, 3, 4),
  (0, "female", 22, 1.5, "no", 3, 17, 5, 5),
  (0, "male", 27, 4, "yes", 4, 17, 4, 4),
  (0, "female", 52, 15, "yes", 5, 14, 1, 5),
  (0, "female", 27, 7, "yes", 2, 12, 1, 2),
  (0, "female", 27, 7, "yes", 3, 12, 1, 4),
  (0, "female", 42, 15, "yes", 2, 14, 1, 4),
  (0, "female", 42, 15, "yes", 4, 14, 5, 4),
  (0, "male", 27, 7, "yes", 4, 14, 3, 3),
  (0, "male", 27, 7, "yes", 2, 20, 6, 2),
  (0, "female", 42, 15, "yes", 3, 12, 3, 3),
  (0, "male", 27, 4, "yes", 3, 16, 3, 5),
  (0, "female", 27, 7, "yes", 3, 14, 1, 4),
  (0, "female", 22, 1.5, "no", 2, 14, 4, 5),
  (0, "female", 27, 4, "yes", 4, 14, 1, 4),
  (0, "female", 22, 4, "no", 4, 14, 5, 5),
  (0, "female", 22, 1.5, "no", 2, 16, 4, 5),
  (0, "male", 47, 15, "no", 4, 14, 5, 4),
  (0, "male", 37, 10, "yes", 2, 18, 6, 2),
  (0, "male", 37, 15, "yes", 3, 17, 5, 4),
  (0, "female", 27, 4, "yes", 2, 16, 1, 4),
  (3, "male", 27, 1.5, "no", 3, 18, 4, 4),
  (3, "female", 27, 4, "yes", 3, 17, 1, 5),
  (7, "male", 37, 15, "yes", 5, 18, 6, 2),
  (12, "female", 32, 10, "yes", 3, 17, 5, 2),
  (1, "male", 22, 0.125, "no", 4, 16, 5, 5),
  (1, "female", 22, 1.5, "yes", 2, 14, 1, 5),
  (12, "male", 37, 15, "yes", 4, 14, 5, 2),
  (7, "female", 22, 1.5, "no", 2, 14, 3, 4),
  (2, "male", 37, 15, "yes", 2, 18, 6, 4),
  (3, "female", 32, 15, "yes", 4, 12, 3, 2),
  (1, "female", 37, 15, "yes", 4, 14, 4, 2),
  (7, "female", 42, 15, "yes", 3, 17, 1, 4),
  (12, "female", 42, 15, "yes", 5, 9, 4, 1),
  (12, "male", 37, 10, "yes", 2, 20, 6, 2),
  (12, "female", 32, 15, "yes", 3, 14, 1, 2),
  (3, "male", 27, 4, "no", 1, 18, 6, 5),
  (7, "male", 37, 10, "yes", 2, 18, 7, 3),
  (7, "female", 27, 4, "no", 3, 17, 5, 5),
  (1, "male", 42, 15, "yes", 4, 16, 5, 5),
  (1, "female", 47, 15, "yes", 5, 14, 4, 5),
  (7, "female", 27, 4, "yes", 3, 18, 5, 4),
  (1, "female", 27, 7, "yes", 5, 14, 1, 4),
  (12, "male", 27, 1.5, "yes", 3, 17, 5, 4),
  (12, "female", 27, 7, "yes", 4, 14, 6, 2),
  (3, "female", 42, 15, "yes", 4, 16, 5, 4),
  (7, "female", 27, 10, "yes", 4, 12, 7, 3),
  (1, "male", 27, 1.5, "no", 2, 18, 5, 2),
  (1, "male", 32, 4, "no", 4, 20, 6, 4),
  (1, "female", 27, 7, "yes", 3, 14, 1, 3),
  (3, "female", 32, 10, "yes", 4, 14, 1, 4),
  (3, "male", 27, 4, "yes", 2, 18, 7, 2),
  (1, "female", 17.5, 0.75, "no", 5, 14, 4, 5),
  (1, "female", 32, 10, "yes", 4, 18, 1, 5),
  (7, "female", 32, 7, "yes", 2, 17, 6, 4),
  (7, "male", 37, 15, "yes", 2, 20, 6, 4),
  (7, "female", 37, 10, "no", 1, 20, 5, 3),
  (12, "female", 32, 10, "yes", 2, 16, 5, 5),
  (7, "male", 52, 15, "yes", 2, 20, 6, 4),
  (7, "female", 42, 15, "yes", 1, 12, 1, 3),
  (1, "male", 52, 15, "yes", 2, 20, 6, 3),
  (2, "male", 37, 15, "yes", 3, 18, 6, 5),
  (12, "female", 22, 4, "no", 3, 12, 3, 4),
  (12, "male", 27, 7, "yes", 1, 18, 6, 2),
  (1, "male", 27, 4, "yes", 3, 18, 5, 5),
  (12, "male", 47, 15, "yes", 4, 17, 6, 5),
  (12, "female", 42, 15, "yes", 4, 12, 1, 1),
  (7, "male", 27, 4, "no", 3, 14, 3, 4),
  (7, "female", 32, 7, "yes", 4, 18, 4, 5),
  (1, "male", 32, 0.417, "yes", 3, 12, 3, 4),
  (3, "male", 47, 15, "yes", 5, 16, 5, 4),
  (12, "male", 37, 15, "yes", 2, 20, 5, 4),
  (7, "male", 22, 4, "yes", 2, 17, 6, 4),
  (1, "male", 27, 4, "no", 2, 14, 4, 5),
  (7, "female", 52, 15, "yes", 5, 16, 1, 3),
  (1, "male", 27, 4, "no", 3, 14, 3, 3),
  (1, "female", 27, 10, "yes", 4, 16, 1, 4),
  (1, "male", 32, 7, "yes", 3, 14, 7, 4),
  (7, "male", 32, 7, "yes", 2, 18, 4, 1),
  (3, "male", 22, 1.5, "no", 1, 14, 3, 2),
  (7, "male", 22, 4, "yes", 3, 18, 6, 4),
  (7, "male", 42, 15, "yes", 4, 20, 6, 4),
  (2, "female", 57, 15, "yes", 1, 18, 5, 4),
  (7, "female", 32, 4, "yes", 3, 18, 5, 2),
  (1, "male", 27, 4, "yes", 1, 16, 4, 4),
  (7, "male", 32, 7, "yes", 4, 16, 1, 4),
  (2, "male", 57, 15, "yes", 1, 17, 4, 4),
  (7, "female", 42, 15, "yes", 4, 14, 5, 2),
  (7, "male", 37, 10, "yes", 1, 18, 5, 3),
  (3, "male", 42, 15, "yes", 3, 17, 6, 1),
  (1, "female", 52, 15, "yes", 3, 14, 4, 4),
  (2, "female", 27, 7, "yes", 3, 17, 5, 3),
  (12, "male", 32, 7, "yes", 2, 12, 4, 2),
  (1, "male", 22, 4, "no", 4, 14, 2, 5),
  (3, "male", 27, 7, "yes", 3, 18, 6, 4),
  (12, "female", 37, 15, "yes", 1, 18, 5, 5),
  (7, "female", 32, 15, "yes", 3, 17, 1, 3),
  (7, "female", 27, 7, "no", 2, 17, 5, 5),
  (1, "female", 32, 7, "yes", 3, 17, 5, 3),
  (1, "male", 32, 1.5, "yes", 2, 14, 2, 4),
  (12, "female", 42, 15, "yes", 4, 14, 1, 2),
  (7, "male", 32, 10, "yes", 3, 14, 5, 4),
  (7, "male", 37, 4, "yes", 1, 20, 6, 3),
  (1, "female", 27, 4, "yes", 2, 16, 5, 3),
  (12, "female", 42, 15, "yes", 3, 14, 4, 3),
  (1, "male", 27, 10, "yes", 5, 20, 6, 5),
  (12, "male", 37, 10, "yes", 2, 20, 6, 2),
  (12, "female", 27, 7, "yes", 1, 14, 3, 3),
  (3, "female", 27, 7, "yes", 4, 12, 1, 2),
  (3, "male", 32, 10, "yes", 2, 14, 4, 4),
  (12, "female", 17.5, 0.75, "yes", 2, 12, 1, 3),
  (12, "female", 32, 15, "yes", 3, 18, 5, 4),
  (2, "female", 22, 7, "no", 4, 14, 4, 3),
  (1, "male", 32, 7, "yes", 4, 20, 6, 5),
  (7, "male", 27, 4, "yes", 2, 18, 6, 2),
  (1, "female", 22, 1.5, "yes", 5, 14, 5, 3),
  (12, "female", 32, 15, "no", 3, 17, 5, 1),
  (12, "female", 42, 15, "yes", 2, 12, 1, 2),
  (7, "male", 42, 15, "yes", 3, 20, 5, 4),
  (12, "male", 32, 10, "no", 2, 18, 4, 2),
  (12, "female", 32, 15, "yes", 3, 9, 1, 1),
  (7, "male", 57, 15, "yes", 5, 20, 4, 5),
  (12, "male", 47, 15, "yes", 4, 20, 6, 4),
  (2, "female", 42, 15, "yes", 2, 17, 6, 3),
  (12, "male", 37, 15, "yes", 3, 17, 6, 3),
  (12, "male", 37, 15, "yes", 5, 17, 5, 2),
  (7, "male", 27, 10, "yes", 2, 20, 6, 4),
  (2, "male", 37, 15, "yes", 2, 16, 5, 4),
  (12, "female", 32, 15, "yes", 1, 14, 5, 2),
  (7, "male", 32, 10, "yes", 3, 17, 6, 3),
  (2, "male", 37, 15, "yes", 4, 18, 5, 1),
  (7, "female", 27, 1.5, "no", 2, 17, 5, 5),
  (3, "female", 47, 15, "yes", 2, 17, 5, 2),
  (12, "male", 37, 15, "yes", 2, 17, 5, 4),
  (12, "female", 27, 4, "no", 2, 14, 5, 5),
  (2, "female", 27, 10, "yes", 4, 14, 1, 5),
  (1, "female", 22, 4, "yes", 3, 16, 1, 3),
  (12, "male", 52, 7, "no", 4, 16, 5, 5),
  (2, "female", 27, 4, "yes", 1, 16, 3, 5),
  (7, "female", 37, 15, "yes", 2, 17, 6, 4),
  (2, "female", 27, 4, "no", 1, 17, 3, 1),
  (12, "female", 17.5, 0.75, "yes", 2, 12, 3, 5),
  (7, "female", 32, 15, "yes", 5, 18, 5, 4),
  (7, "female", 22, 4, "no", 1, 16, 3, 5),
  (2, "male", 32, 4, "yes", 4, 18, 6, 4),
  (1, "female", 22, 1.5, "yes", 3, 18, 5, 2),
  (3, "female", 42, 15, "yes", 2, 17, 5, 4),
  (1, "male", 32, 7, "yes", 4, 16, 4, 4),
  (12, "male", 37, 15, "no", 3, 14, 6, 2),
  (1, "male", 42, 15, "yes", 3, 16, 6, 3),
  (1, "male", 27, 4, "yes", 1, 18, 5, 4),
  (2, "male", 37, 15, "yes", 4, 20, 7, 3),
  (7, "male", 37, 15, "yes", 3, 20, 6, 4),
  (3, "male", 22, 1.5, "no", 2, 12, 3, 3),
  (3, "male", 32, 4, "yes", 3, 20, 6, 2),
  (2, "male", 32, 15, "yes", 5, 20, 6, 5),
  (12, "female", 52, 15, "yes", 1, 18, 5, 5),
  (12, "male", 47, 15, "no", 1, 18, 6, 5),
  (3, "female", 32, 15, "yes", 4, 16, 4, 4),
  (7, "female", 32, 15, "yes", 3, 14, 3, 2),
  (7, "female", 27, 7, "yes", 4, 16, 1, 2),
  (12, "male", 42, 15, "yes", 3, 18, 6, 2),
  (7, "female", 42, 15, "yes", 2, 14, 3, 2),
  (12, "male", 27, 7, "yes", 2, 17, 5, 4),
  (3, "male", 32, 10, "yes", 4, 14, 4, 3),
  (7, "male", 47, 15, "yes", 3, 16, 4, 2),
  (1, "male", 22, 1.5, "yes", 1, 12, 2, 5),
  (7, "female", 32, 10, "yes", 2, 18, 5, 4),
  (2, "male", 32, 10, "yes", 2, 17, 6, 5),
  (2, "male", 22, 7, "yes", 3, 18, 6, 2),
  (1, "female", 32, 15, "yes", 3, 14, 1, 5)
)

// Convert the local list into a DataFrame with named columns
// (relies on the implicits imported from `spark.implicits._`).
val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

随机森林建模

// Register the DataFrame as a temporary view so it can be queried with Spark SQL.
data.createOrReplaceTempView("data")

// Convert string-typed columns into numeric values.
// label is 0 when affairs = 0 and 1 otherwise; gender and children become 0/1 flags.
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

val dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")

val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Assemble the individual columns into a single feature vector column.
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))

// Index the labels, adding metadata to the label column.
// The indexer is fitted on the whole data set (vecDF), not only on the training split.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
//labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF)
//featureIndexer.transform(vecDF).show(10, truncate = false)

// Configure the random forest classifier.
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10)

// Convert indexed labels back to the original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingDF)

// Output all parameter values of the random forest stage
// (the value is only displayed when run in an interactive shell).
model.stages(2).extractParamMap()

// Make predictions.
val predictions = model.transform(testDF)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)

// Select (prediction, true label) and compute the test error.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// Index 2 in stages(2) corresponds to "rf" in the pipeline definition above;
// the stage is cast to RandomForestClassificationModel to access the learned forest.
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n${rfModel.toDebugString}")

代码执行结果

vecDF.show(10, truncate = false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|
|0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |
|0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|
|0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|
|0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|
|0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |
|0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|
|0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|
|0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|
|0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
only showing top 10 rows // 将数据分为训练和测试集(30%进行测试)
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))
trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields]
testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields] // 索引标签,将元数据添加到标签列中
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_37df210602df
//labelIndexer.transform(vecDF).show(10, truncate = false) // 自动识别分类的特征,并对它们进行索引
// 具有大于5个不同的值的特征被视为连续。
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_9595c228f520
//featureIndexer.transform(vecDF).show(10, truncate = false)
// 训练随机森林模型
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10)
rf: org.apache.spark.ml.classification.RandomForestClassifier = rfc_d0e7623d0b10
// 将索引标签转换回原始标签
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_32d6938f2c94
// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
pipeline: org.apache.spark.ml.Pipeline = pipeline_97716da42fed
// Train model. This also runs the indexers.
val model = pipeline.fit(trainingDF)
model: org.apache.spark.ml.PipelineModel = pipeline_97716da42fed
// 输出随机森林模型的全部参数值
model.stages(2).extractParamMap()
res10: org.apache.spark.ml.param.ParamMap =
{
rfc_0d830180d598-cacheNodeIds: false,
rfc_0d830180d598-checkpointInterval: 10,
rfc_0d830180d598-featureSubsetStrategy: auto,
rfc_0d830180d598-featuresCol: indexedFeatures,
rfc_0d830180d598-impurity: gini,
rfc_0d830180d598-labelCol: indexedLabel,
rfc_0d830180d598-maxBins: 32,
rfc_0d830180d598-maxDepth: 5,
rfc_0d830180d598-maxMemoryInMB: 256,
rfc_0d830180d598-minInfoGain: 0.0,
rfc_0d830180d598-minInstancesPerNode: 1,
rfc_0d830180d598-predictionCol: prediction,
rfc_0d830180d598-probabilityCol: probability,
rfc_0d830180d598-rawPredictionCol: rawPrediction,
rfc_0d830180d598-seed: 207336481,
rfc_0d830180d598-subsamplingRate: 1.0
}
// 作出预测
val predictions = model.transform(testDF)
predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 14 more fields]
predictions.select("predictedLabel", "label", "features").show(10,false)
+--------------+-----+-------------------------------------+
|predictedLabel|label|features |
+--------------+-----+-------------------------------------+
|0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]|
|0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|
|0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]|
|0.0 |0.0 |[0.0,22.0,0.417,0.0,4.0,14.0,5.0,5.0]|
|0.0 |0.0 |[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|
|0.0 |0.0 |[0.0,22.0,0.75,0.0,5.0,18.0,1.0,5.0] |
|0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] |
|0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,16.0,5.0,3.0] |
|0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0] |
|0.0 |0.0 |[0.0,22.0,1.5,1.0,3.0,12.0,1.0,3.0] |
+--------------+-----+-------------------------------------+
only showing top 10 rows
// 选择(预测标签,实际标签),并计算测试误差
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_13a195abc422
val accuracy = evaluator.evaluate(predictions)
accuracy: Double = 0.7365591397849462
println("Test Error = " + (1.0 - accuracy))
Test Error = 0.26344086021505375
// 这里的stages(2)中的“2”对应pipeline中的“rf”,将model强制转换为RandomForestClassificationModel类型
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
rfModel: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees
println("Learned classification forest model:\n" + rfModel.toDebugString)
Learned classification forest model:
RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees
Tree 0 (weight 1.0):
If (feature 2 <= 1.5)
If (feature 5 <= 12.0)
If (feature 6 <= 1.0)
Predict: 0.0
Else (feature 6 > 1.0)
If (feature 2 <= 0.125)
Predict: 0.0
Else (feature 2 > 0.125)
Predict: 1.0
Else (feature 5 > 12.0)
If (feature 0 in {0.0})
If (feature 5 <= 16.0)
Predict: 0.0
Else (feature 5 > 16.0)
If (feature 1 <= 22.0)
Predict: 0.0
Else (feature 1 > 22.0)
Predict: 0.0
Else (feature 0 not in {0.0})
If (feature 2 <= 0.75)
If (feature 4 in {0.0,1.0,2.0,4.0})
Predict: 0.0
Else (feature 4 not in {0.0,1.0,2.0,4.0})
Predict: 0.0
Else (feature 2 > 0.75)
If (feature 1 <= 22.0)
Predict: 0.0
Else (feature 1 > 22.0)
Predict: 1.0
Else (feature 2 > 1.5)
If (feature 1 <= 42.0)
If (feature 1 <= 27.0)
If (feature 5 <= 16.0)
If (feature 6 <= 5.0)
Predict: 0.0
Else (feature 6 > 5.0)
Predict: 1.0
Else (feature 5 > 16.0)
If (feature 4 in {3.0})
Predict: 0.0
Else (feature 4 not in {3.0})
Predict: 0.0
Else (feature 1 > 27.0)
If (feature 4 in {0.0,3.0,4.0})
If (feature 2 <= 4.0)
Predict: 1.0
Else (feature 2 > 4.0)
Predict: 0.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 1.0
Else (feature 1 > 42.0)
If (feature 4 in {2.0,4.0})
Predict: 0.0
Else (feature 4 not in {2.0,4.0})
If (feature 4 in {0.0})
Predict: 1.0
Else (feature 4 not in {0.0})
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 0.0
Tree 1 (weight 1.0):
If (feature 7 in {0.0,2.0,4.0})
If (feature 7 in {0.0})
If (feature 1 <= 42.0)
If (feature 4 in {1.0})
Predict: 0.0
Else (feature 4 not in {1.0})
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 7 not in {0.0})
If (feature 1 <= 17.5)
If (feature 4 in {3.0})
Predict: 0.0
Else (feature 4 not in {3.0})
Predict: 1.0
Else (feature 1 > 17.5)
If (feature 0 in {0.0})
If (feature 4 in {1.0,3.0,4.0})
Predict: 0.0
Else (feature 4 not in {1.0,3.0,4.0})
Predict: 0.0
Else (feature 0 not in {0.0})
If (feature 6 <= 2.0)
Predict: 1.0
Else (feature 6 > 2.0)
Predict: 0.0
Else (feature 7 not in {0.0,2.0,4.0})
If (feature 3 in {0.0})
If (feature 5 <= 14.0)
If (feature 4 in {1.0,3.0})
Predict: 0.0
Else (feature 4 not in {1.0,3.0})
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 1.0
Else (feature 5 > 14.0)
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
If (feature 4 in {0.0,2.0,3.0,4.0})
Predict: 0.0
Else (feature 4 not in {0.0,2.0,3.0,4.0})
Predict: 1.0
Else (feature 3 not in {0.0})
If (feature 5 <= 12.0)
If (feature 0 in {1.0})
Predict: 0.0
Else (feature 0 not in {1.0})
If (feature 6 <= 1.0)
Predict: 0.0
Else (feature 6 > 1.0)
Predict: 0.0
Else (feature 5 > 12.0)
If (feature 4 in {0.0,2.0,3.0,4.0})
If (feature 1 <= 47.0)
Predict: 0.0
Else (feature 1 > 47.0)
Predict: 1.0
Else (feature 4 not in {0.0,2.0,3.0,4.0})
If (feature 1 <= 22.0)
Predict: 1.0
Else (feature 1 > 22.0)
Predict: 0.0
Tree 2 (weight 1.0):
If (feature 7 in {0.0})
If (feature 4 in {1.0})
Predict: 0.0
Else (feature 4 not in {1.0})
If (feature 6 <= 5.0)
If (feature 1 <= 42.0)
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 6 > 5.0)
Predict: 0.0
Else (feature 7 not in {0.0})
If (feature 5 <= 16.0)
If (feature 7 in {1.0})
If (feature 6 <= 4.0)
If (feature 2 <= 7.0)
Predict: 0.0
Else (feature 2 > 7.0)
Predict: 1.0
Else (feature 6 > 4.0)
Predict: 1.0
Else (feature 7 not in {1.0})
If (feature 3 in {1.0})
If (feature 1 <= 17.5)
Predict: 1.0
Else (feature 1 > 17.5)
Predict: 0.0
Else (feature 3 not in {1.0})
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 0.0
Else (feature 5 > 16.0)
If (feature 3 in {0.0})
If (feature 4 in {4.0})
Predict: 0.0
Else (feature 4 not in {4.0})
If (feature 5 <= 18.0)
Predict: 0.0
Else (feature 5 > 18.0)
Predict: 0.0
Else (feature 3 not in {0.0})
If (feature 4 in {0.0,3.0,4.0})
If (feature 7 in {2.0})
Predict: 0.0
Else (feature 7 not in {2.0})
Predict: 0.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 1.0
Tree 3 (weight 1.0):
If (feature 3 in {0.0})
If (feature 7 in {3.0})
Predict: 0.0
Else (feature 7 not in {3.0})
If (feature 2 <= 10.0)
If (feature 4 in {2.0,3.0,4.0})
If (feature 4 in {4.0})
Predict: 0.0
Else (feature 4 not in {4.0})
Predict: 0.0
Else (feature 4 not in {2.0,3.0,4.0})
If (feature 7 in {0.0,2.0,4.0})
Predict: 0.0
Else (feature 7 not in {0.0,2.0,4.0})
Predict: 1.0
Else (feature 2 > 10.0)
Predict: 1.0
Else (feature 3 not in {0.0})
If (feature 6 <= 2.0)
If (feature 5 <= 16.0)
If (feature 7 in {0.0,1.0,2.0,4.0})
If (feature 4 in {0.0,1.0,3.0,4.0})
Predict: 0.0
Else (feature 4 not in {0.0,1.0,3.0,4.0})
Predict: 1.0
Else (feature 7 not in {0.0,1.0,2.0,4.0})
If (feature 1 <= 22.0)
Predict: 0.0
Else (feature 1 > 22.0)
Predict: 0.0
Else (feature 5 > 16.0)
If (feature 7 in {0.0,1.0,3.0})
Predict: 0.0
Else (feature 7 not in {0.0,1.0,3.0})
Predict: 1.0
Else (feature 6 > 2.0)
If (feature 4 in {0.0,3.0,4.0})
If (feature 7 in {0.0,2.0,3.0,4.0})
If (feature 4 in {3.0,4.0})
Predict: 0.0
Else (feature 4 not in {3.0,4.0})
Predict: 0.0
Else (feature 7 not in {0.0,2.0,3.0,4.0})
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 1.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 1 <= 22.0)
If (feature 5 <= 14.0)
Predict: 1.0
Else (feature 5 > 14.0)
Predict: 1.0
Else (feature 1 > 22.0)
If (feature 6 <= 6.0)
Predict: 0.0
Else (feature 6 > 6.0)
Predict: 1.0
Tree 4 (weight 1.0):
If (feature 7 in {0.0,2.0,4.0})
If (feature 7 in {0.0})
If (feature 6 <= 5.0)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
If (feature 4 in {2.0,4.0})
Predict: 1.0
Else (feature 4 not in {2.0,4.0})
Predict: 1.0
Else (feature 6 > 5.0)
Predict: 0.0
Else (feature 7 not in {0.0})
If (feature 2 <= 1.5)
If (feature 5 <= 12.0)
If (feature 2 <= 0.125)
Predict: 0.0
Else (feature 2 > 0.125)
Predict: 0.0
Else (feature 5 > 12.0)
If (feature 1 <= 17.5)
Predict: 1.0
Else (feature 1 > 17.5)
Predict: 0.0
Else (feature 2 > 1.5)
If (feature 2 <= 7.0)
If (feature 4 in {1.0,3.0,4.0})
Predict: 0.0
Else (feature 4 not in {1.0,3.0,4.0})
Predict: 0.0
Else (feature 2 > 7.0)
If (feature 5 <= 16.0)
Predict: 0.0
Else (feature 5 > 16.0)
Predict: 0.0
Else (feature 7 not in {0.0,2.0,4.0})
If (feature 5 <= 12.0)
Predict: 0.0
Else (feature 5 > 12.0)
If (feature 4 in {0.0,3.0,4.0})
If (feature 1 <= 47.0)
If (feature 1 <= 22.0)
Predict: 0.0
Else (feature 1 > 22.0)
Predict: 0.0
Else (feature 1 > 47.0)
Predict: 1.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 1 <= 27.0)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 0.0
Else (feature 1 > 27.0)
If (feature 5 <= 14.0)
Predict: 1.0
Else (feature 5 > 14.0)
Predict: 1.0
Tree 5 (weight 1.0):
If (feature 7 in {0.0})
If (feature 1 <= 42.0)
If (feature 6 <= 4.0)
Predict: 1.0
Else (feature 6 > 4.0)
If (feature 4 in {1.0})
Predict: 0.0
Else (feature 4 not in {1.0})
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 7 not in {0.0})
If (feature 2 <= 1.5)
If (feature 4 in {0.0,2.0,3.0})
If (feature 1 <= 22.0)
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 0.0
Else (feature 1 > 22.0)
Predict: 0.0
Else (feature 4 not in {0.0,2.0,3.0})
If (feature 1 <= 17.5)
If (feature 6 <= 4.0)
Predict: 1.0
Else (feature 6 > 4.0)
Predict: 0.0
Else (feature 1 > 17.5)
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 0.0
Else (feature 2 > 1.5)
If (feature 6 <= 5.0)
If (feature 5 <= 17.0)
If (feature 7 in {2.0,4.0})
Predict: 0.0
Else (feature 7 not in {2.0,4.0})
Predict: 0.0
Else (feature 5 > 17.0)
If (feature 6 <= 1.0)
Predict: 0.0
Else (feature 6 > 1.0)
Predict: 0.0
Else (feature 6 > 5.0)
If (feature 4 in {0.0,3.0,4.0})
If (feature 7 in {3.0,4.0})
Predict: 0.0
Else (feature 7 not in {3.0,4.0})
Predict: 0.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 6 <= 6.0)
Predict: 0.0
Else (feature 6 > 6.0)
Predict: 0.0
Tree 6 (weight 1.0):
If (feature 4 in {0.0,3.0,4.0})
If (feature 5 <= 12.0)
If (feature 7 in {1.0,2.0,3.0,4.0})
Predict: 0.0
Else (feature 7 not in {1.0,2.0,3.0,4.0})
If (feature 6 <= 3.0)
Predict: 0.0
Else (feature 6 > 3.0)
Predict: 1.0
Else (feature 5 > 12.0)
If (feature 7 in {0.0,1.0,2.0})
If (feature 6 <= 1.0)
If (feature 7 in {0.0,2.0})
Predict: 0.0
Else (feature 7 not in {0.0,2.0})
Predict: 0.0
Else (feature 6 > 1.0)
If (feature 1 <= 37.0)
Predict: 1.0
Else (feature 1 > 37.0)
Predict: 0.0
Else (feature 7 not in {0.0,1.0,2.0})
If (feature 1 <= 17.5)
If (feature 4 in {3.0})
Predict: 0.0
Else (feature 4 not in {3.0})
Predict: 1.0
Else (feature 1 > 17.5)
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 0.0
Else (feature 4 not in {0.0,3.0,4.0})
If (feature 7 in {0.0,4.0})
If (feature 5 <= 12.0)
If (feature 2 <= 0.125)
Predict: 0.0
Else (feature 2 > 0.125)
If (feature 1 <= 17.5)
Predict: 1.0
Else (feature 1 > 17.5)
Predict: 0.0
Else (feature 5 > 12.0)
If (feature 7 in {0.0})
If (feature 1 <= 42.0)
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 7 not in {0.0})
If (feature 2 <= 1.5)
Predict: 0.0
Else (feature 2 > 1.5)
Predict: 0.0
Else (feature 7 not in {0.0,4.0})
If (feature 6 <= 4.0)
If (feature 7 in {3.0})
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 0.0
Else (feature 7 not in {3.0})
If (feature 5 <= 16.0)
Predict: 0.0
Else (feature 5 > 16.0)
Predict: 1.0
Else (feature 6 > 4.0)
If (feature 6 <= 6.0)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 1.0
Else (feature 6 > 6.0)
If (feature 5 <= 18.0)
Predict: 1.0
Else (feature 5 > 18.0)
Predict: 0.0
Tree 7 (weight 1.0):
If (feature 7 in {0.0,2.0,4.0})
If (feature 2 <= 1.5)
If (feature 4 in {1.0,2.0,3.0})
If (feature 1 <= 17.5)
Predict: 1.0
Else (feature 1 > 17.5)
Predict: 0.0
Else (feature 4 not in {1.0,2.0,3.0})
If (feature 5 <= 14.0)
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 1.0
Else (feature 5 > 14.0)
Predict: 0.0
Else (feature 2 > 1.5)
If (feature 7 in {0.0,2.0})
If (feature 4 in {1.0,3.0,4.0})
If (feature 5 <= 16.0)
Predict: 0.0
Else (feature 5 > 16.0)
Predict: 0.0
Else (feature 4 not in {1.0,3.0,4.0})
If (feature 6 <= 5.0)
Predict: 1.0
Else (feature 6 > 5.0)
Predict: 0.0
Else (feature 7 not in {0.0,2.0})
If (feature 4 in {0.0,1.0,3.0})
If (feature 1 <= 42.0)
Predict: 0.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 4 not in {0.0,1.0,3.0})
If (feature 5 <= 16.0)
Predict: 0.0
Else (feature 5 > 16.0)
Predict: 0.0
Else (feature 7 not in {0.0,2.0,4.0})
If (feature 2 <= 0.75)
Predict: 0.0
Else (feature 2 > 0.75)
If (feature 4 in {4.0})
If (feature 6 <= 5.0)
If (feature 1 <= 37.0)
Predict: 1.0
Else (feature 1 > 37.0)
Predict: 0.0
Else (feature 6 > 5.0)
Predict: 0.0
Else (feature 4 not in {4.0})
If (feature 5 <= 12.0)
If (feature 1 <= 27.0)
Predict: 0.0
Else (feature 1 > 27.0)
Predict: 0.0
Else (feature 5 > 12.0)
If (feature 7 in {1.0})
Predict: 1.0
Else (feature 7 not in {1.0})
Predict: 0.0
Tree 8 (weight 1.0):
If (feature 5 <= 16.0)
If (feature 4 in {0.0,1.0})
If (feature 0 in {0.0})
If (feature 2 <= 0.75)
If (feature 1 <= 17.5)
Predict: 1.0
Else (feature 1 > 17.5)
Predict: 0.0
Else (feature 2 > 0.75)
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 0.0
Else (feature 0 not in {0.0})
If (feature 5 <= 12.0)
Predict: 1.0
Else (feature 5 > 12.0)
If (feature 7 in {2.0,4.0})
Predict: 0.0
Else (feature 7 not in {2.0,4.0})
Predict: 0.0
Else (feature 4 not in {0.0,1.0})
If (feature 7 in {0.0,2.0,3.0,4.0})
If (feature 1 <= 22.0)
If (feature 6 <= 3.0)
Predict: 0.0
Else (feature 6 > 3.0)
Predict: 0.0
Else (feature 1 > 22.0)
If (feature 6 <= 6.0)
Predict: 0.0
Else (feature 6 > 6.0)
Predict: 1.0
Else (feature 7 not in {0.0,2.0,3.0,4.0})
If (feature 1 <= 42.0)
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 5 > 16.0)
If (feature 5 <= 18.0)
If (feature 4 in {3.0})
If (feature 7 in {1.0,2.0,3.0})
Predict: 0.0
Else (feature 7 not in {1.0,2.0,3.0})
If (feature 6 <= 5.0)
Predict: 0.0
Else (feature 6 > 5.0)
Predict: 0.0
Else (feature 4 not in {3.0})
If (feature 2 <= 0.75)
Predict: 0.0
Else (feature 2 > 0.75)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 1.0
Else (feature 5 > 18.0)
If (feature 1 <= 27.0)
If (feature 7 in {3.0})
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 1.0
Else (feature 7 not in {3.0})
If (feature 2 <= 4.0)
Predict: 0.0
Else (feature 2 > 4.0)
Predict: 1.0
Else (feature 1 > 27.0)
If (feature 6 <= 5.0)
If (feature 6 <= 4.0)
Predict: 0.0
Else (feature 6 > 4.0)
Predict: 0.0
Else (feature 6 > 5.0)
If (feature 4 in {3.0,4.0})
Predict: 0.0
Else (feature 4 not in {3.0,4.0})
Predict: 0.0
Tree 9 (weight 1.0):
If (feature 5 <= 16.0)
If (feature 6 <= 2.0)
If (feature 1 <= 42.0)
If (feature 6 <= 1.0)
If (feature 5 <= 9.0)
Predict: 1.0
Else (feature 5 > 9.0)
Predict: 0.0
Else (feature 6 > 1.0)
If (feature 1 <= 27.0)
Predict: 0.0
Else (feature 1 > 27.0)
Predict: 1.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 6 > 2.0)
If (feature 1 <= 27.0)
If (feature 5 <= 14.0)
If (feature 6 <= 3.0)
Predict: 0.0
Else (feature 6 > 3.0)
Predict: 0.0
Else (feature 5 > 14.0)
Predict: 0.0
Else (feature 1 > 27.0)
If (feature 4 in {1.0,2.0,4.0})
If (feature 5 <= 9.0)
Predict: 0.0
Else (feature 5 > 9.0)
Predict: 0.0
Else (feature 4 not in {1.0,2.0,4.0})
If (feature 7 in {2.0,3.0,4.0})
Predict: 0.0
Else (feature 7 not in {2.0,3.0,4.0})
Predict: 1.0
Else (feature 5 > 16.0)
If (feature 6 <= 4.0)
If (feature 4 in {3.0})
Predict: 0.0
Else (feature 4 not in {3.0})
If (feature 1 <= 42.0)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 0.0
Else (feature 1 > 42.0)
Predict: 1.0
Else (feature 6 > 4.0)
If (feature 4 in {3.0,4.0})
If (feature 1 <= 37.0)
If (feature 3 in {0.0})
Predict: 0.0
Else (feature 3 not in {0.0})
Predict: 0.0
Else (feature 1 > 37.0)
If (feature 1 <= 42.0)
Predict: 0.0
Else (feature 1 > 42.0)
Predict: 0.0
Else (feature 4 not in {3.0,4.0})
If (feature 4 in {0.0,2.0})
If (feature 7 in {0.0,1.0,2.0})
Predict: 1.0
Else (feature 7 not in {0.0,1.0,2.0})
Predict: 1.0
Else (feature 4 not in {0.0,2.0})
If (feature 0 in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
Predict: 0.0

随机森林模型调优

// Random forest hyper-parameter tuning with cross-validation.
// Relies on `featuresArray` and `dataLabelDF` defined earlier in the file.

// Assemble the individual feature columns into a single "features" vector column.
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))

// Index the labels, adding metadata to the label column.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
//labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF)
//featureIndexer.transform(vecDF).show(10, truncate = false)

// Random forest estimator; its hyper-parameters are supplied by the grid below.
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

// Convert the indexed predictions back to the original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Parameter grid.
//   impurity               impurity measure
//   maxBins                maximum number of bins used to discretize continuous features
//   maxDepth               maximum depth of a tree
//   minInfoGain            minimum information gain required to split a node
//                          (the original note gives the range as [0, 1])
//   minInstancesPerNode    minimum number of instances per node, >= 1
//   numTrees               number of trees
//   featureSubsetStrategy  number of features considered for the split at each tree node;
//                          there are many possible values, see the official documentation
//   subsamplingRate        fraction of the training data given to each tree, range (0, 1]
//   maxMemoryInMB          if too small, one node is split per iteration and its aggregates
//                          may exceed this size
//   checkpointInterval     checkpoint interval (>= 1), or -1 to disable checkpointing;
//                          e.g. 10 means the cache is checkpointed every 10 iterations
//   cacheNodeIds           if false, the algorithm passes trees to executors to match
//                          instances with nodes; if true, it caches node IDs for each
//                          instance, which can speed up training of deeper trees; how often
//                          the cache is checkpointed is controlled by checkpointInterval
//                          (default = false)
//   seed                   random seed
//
// NOTE(review): this grid expands to 2*2*3*3*2*2*2*2*2*2*2*2 = 9216 parameter combinations;
// with 5 folds that is 46080 pipeline fits. maxMemoryInMB, checkpointInterval and
// cacheNodeIds look like resource/performance settings rather than model hyper-parameters;
// consider dropping them from the grid to keep the search tractable.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.impurity, Array("entropy", "gini"))
  .addGrid(rf.maxBins, Array(32, 64))
  .addGrid(rf.maxDepth, Array(5, 7, 10))
  .addGrid(rf.minInfoGain, Array(0.0, 0.5, 1.0))
  .addGrid(rf.minInstancesPerNode, Array(10, 20))
  .addGrid(rf.numTrees, Array(20, 50))
  .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
  .addGrid(rf.subsamplingRate, Array(0.8, 1.0))
  .addGrid(rf.maxMemoryInMB, Array(256, 512))
  .addGrid(rf.checkpointInterval, Array(10, 20))
  .addGrid(rf.cacheNodeIds, Array(false, true))
  .addGrid(rf.seed, Array(123456L, 111L))
  .build()

// Select (prediction, true label) and compute the test error. Both indexedLabel and
// prediction are indexed values, so they can be compared directly.
val classEvaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")

// Configure cross-validation.
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(classEvaluator).setEstimatorParamMaps(paramGrid).setNumFolds(5)

// Run cross-validation and pick the best parameter combination.
val cvModel = cv.fit(trainingDF)

// Inspect all parameters of the fitted cross-validation model.
cvModel.extractParamMap()
// cvModel.avgMetrics.length == cvModel.getEstimatorParamMaps.length;
// the elements of cvModel.avgMetrics and cvModel.getEstimatorParamMaps correspond one-to-one.
cvModel.avgMetrics.length
cvModel.avgMetrics // average metric for each parameter combination
cvModel.getEstimatorParamMaps.length
cvModel.getEstimatorParamMaps // the set of parameter combinations
cvModel.getEvaluator.extractParamMap() // parameters of the evaluator
cvModel.getEvaluator.isLargerBetter // whether a larger or a smaller metric value is better; determined from the evaluation metric
cvModel.getNumFolds // number of cross-validation folds

//################################
// Test the model.
val predictDF: DataFrame = cvModel.transform(testDF).selectExpr(
  //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth", "features",
  "predictedLabel", "label", "features")
predictDF.show(20, false)