逻辑回归：计算 AUC，评价推荐（二分类）质量

时间:2022-12-07 19:15:38
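有两个概念需要弄清楚：
1. ROC 曲线：分类阈值从高到低变化时，以假正例率（FPR）为横轴、真正例率（TPR）为纵轴画出的曲线。
2. AUC：ROC 曲线下的面积，等于"随机取一个正例和一个负例，正例得分高于负例得分"的概率（得分相同记 1/2）。把样本按得分升序排列，排名 rank 从 1 开始（得分相同取平均排名），则
   $$\mathrm{AUC} = \frac{\sum_{i \in \text{正例}} \mathrm{rank}_i - \frac{M(M+1)}{2}}{M N}$$
   其中 M 为正例个数，N 为负例个数。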

package org.apache.spark.mllib.classification

import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.{SparkConf, SparkContext}


/**
  * Evaluates a logistic-regression model by its AUC (area under the ROC curve).
  *
  * Created by root on 2016/9/24.
  *
  * Pipeline: load LabeledPoint text data, split it 10% test / 90% train, train with
  * LogisticRegressionWithSGD, score every test point with sigmoid(w·x + b) and compute
  * AUC with the rank-sum (Mann-Whitney) formula.
  *
  * NOTE(review): the file is declared in package org.apache.spark.mllib.classification,
  * presumably so that the package-private `BLAS.dot` is accessible — confirm before moving it.
  */
object auc {

  /** Input path used when no path is passed as the first command-line argument. */
  private val DefaultInputPath = "D:\\bigdataworkspaces\\dataout\\dd\\part-00000"

  /**
    * Trains the model and prints its AUC on the held-out test split.
    *
    * @param args optional; args(0) overrides the input path (anything readable by
    *             MLUtils.loadLabeledPoints). Without arguments DefaultInputPath is used.
    */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    val conf = new SparkConf().setAppName("auc").setMaster("local[*]")
    val sc = new SparkContext(conf)
    try {
      val inputPath = args.headOption.getOrElse(DefaultInputPath)
      // Expected input: text lines such as "(0.0,[1.14,0.91,...])" / "(1.0,[3.45,3.70,...])",
      // as written by LogisticRegressionDataGenerator
      // (<master> <output_dir> <num_examples> <num_features> <num_partitions>,
      // e.g. local path 100000 100 1).
      val data: RDD[LabeledPoint] = MLUtils.loadLabeledPoints(sc, inputPath)

      // Weights 1:9 -> first split (~10%) is the test set, second (~90%) the training set.
      val splits = data.randomSplit(Array(1.0, 9.0))
      val testData = splits(0)
      // SGD makes one pass over the training data per iteration; cache it so the
      // input file is not re-read and re-sampled every time.
      val trainData = splits(1).cache()

      // 100 SGD iterations, step size 0.8.
      val model: LogisticRegressionModel = LogisticRegressionWithSGD.train(trainData, 100, 0.8)

      // Score each test point with P(label = 1) = sigmoid(w·x + b), keeping score and
      // label together in one pass instead of zipping two separately mapped RDDs.
      val scoreAndLabels: RDD[(Double, Double)] = testData.map { point =>
        val margin = dot(model.weights, point.features) + model.intercept
        (1.0 / (1.0 + math.exp(-margin)), point.label)
      }.cache()

      println("Auc==" + computeAuc(scoreAndLabels))

      scoreAndLabels.unpersist()
      trainData.unpersist()
    } finally {
      // Always release the local Spark context, even if loading or training fails.
      sc.stop()
    }
  }

  /**
    * Computes AUC from (score, label) pairs with the rank-sum formula
    *   AUC = (sum of positive ranks - M(M+1)/2) / (M * N),
    * where ranks are 1-based positions in ascending score order, M is the number of
    * positives (label == 1.0) and N is the number of all other examples.
    * Tied scores all receive the average of the ranks they span, so a positive and a
    * negative with the same score contribute 1/2 (the standard AUC tie convention).
    *
    * @param scoreAndLabels (score, label) pairs; higher score means "more likely positive"
    * @return the AUC, or Double.NaN if there are no positives or no negatives
    */
  private def computeAuc(scoreAndLabels: RDD[(Double, Double)]): Double = {
    // Sorting range-partitions by score, so equal scores get contiguous 0-based indices.
    val ranked: RDD[((Double, Double), Long)] = scoreAndLabels.sortBy(_._1).zipWithIndex()

    val positiveRankSum = ranked
      .map { case ((score, label), index) =>
        (score, (index, 1L, if (label == 1.0) 1L else 0L))
      }
      // Per distinct score: (first index, group size, number of positives).
      .reduceByKey((a, b) => (math.min(a._1, b._1), a._2 + b._2, a._3 + b._3))
      .map { case (_, (firstIndex, groupSize, positives)) =>
        // 1-based ranks of the group are firstIndex+1 .. firstIndex+groupSize;
        // their mean is firstIndex + (groupSize + 1) / 2.
        positives * (firstIndex + (groupSize + 1) / 2.0)
      }
      .sum()

    val m = scoreAndLabels.filter(_._2 == 1.0).count().toDouble
    val n = scoreAndLabels.count() - m
    if (m == 0 || n == 0) Double.NaN
    else (positiveRankSum - m * (m + 1) / 2) / (m * n)
  }

}



结果:
"C:\Program Files\Java\jdk1.7.0_80\bin\java" -Didea.launcher.port=7532 "-Didea.launcher.bin.path=C:\Program Files (x86)\JetBrains\IntelliJ IDEA Community Edition 2016.1.3\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.7.0_80\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\access-bridge-64.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\dnsns.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\jaccess.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\localedata.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\sunec.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\sunjce_provider.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\sunmscapi.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\ext\zipfs.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\javaws.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\jce.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\jfr.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\jfxrt.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\jsse.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\management-agent.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\plugin.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\resources.jar;C:\Program Files\Java\jdk1.7.0_80\jre\lib\rt.jar;D:\bigdataworkspaces\recommder\out\production\recommder;F:\scala\lib\scala-actors-migration.jar;F:\scala\lib\scala-actors.jar;F:\scala\lib\scala-library.jar;F:\scala\lib\scala-reflect.jar;F:\scala\lib\scala-swing.jar;D:\bigdataworkspaces\recommder\lib\spark-assembly-1.6.0-hadoop2.6.0.jar;C:\Program Files (x86)\JetBrains\IntelliJ IDEA Community Edition 2016.1.3\lib\idea_rt.jar" com.intellij.rt.execution.application.AppMain org.apache.spark.mllib.classification.auc
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
16/09/24 15:21:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
16/09/24 15:21:05 INFO Slf4jLogger: Slf4jLogger started
16/09/24 15:21:05 INFO Remoting: Starting remoting
16/09/24 15:21:06 INFO Remoting: Remoting started; listening on addresses :[akka.tcp://sparkDriverActorSystem@192.168.1.171:16869]
16/09/24 15:21:11 WARN : Your hostname, root resolves to a loopback/non-reachable address: fe80:0:0:0:0:5efe:c0a8:8c01%17, but we couldn't find any external IP address!
16/09/24 15:21:23 INFO FileInputFormat: Total input paths to process : 1
16/09/24 15:21:24 INFO deprecation: mapred.tip.id is deprecated. Instead, use mapreduce.task.id
16/09/24 15:21:24 INFO deprecation: mapred.task.id is deprecated. Instead, use mapreduce.task.attempt.id
16/09/24 15:21:24 INFO deprecation: mapred.task.is.map is deprecated. Instead, use mapreduce.task.ismap
16/09/24 15:21:24 INFO deprecation: mapred.task.partition is deprecated. Instead, use mapreduce.task.partition
16/09/24 15:21:24 INFO deprecation: mapred.job.id is deprecated. Instead, use mapreduce.job.id
16/09/24 15:21:55 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
16/09/24 15:21:55 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
[Stage 209:===============================================>         (5 + 1) / 6]Auc==0.9998032270759544

Process finished with exit code 0