1. Creating the Relation
package com.spark.datasource.demo
import org.apache.spark.sql.sources._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.rdd.RDD
import java.sql.{ DriverManager, ResultSet }
import org.apache.spark.sql.{ Row, SQLContext }
import scala.collection.mutable.ArrayBuffer
import org.slf4j.LoggerFactory
import java.io._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import scala.collection.JavaConversions._
/**
 * Steps to implement a user-defined data source:
 * 1.1 Create a DefaultSource that extends RelationProvider.
 *     The class name must be DefaultSource.
 * 1.2 Implement a user-defined Relation.
 *     A Relation supports 4 scanning strategies:
 *     <1> full table scan: extend TableScan
 *     <2> column pruning: extend PrunedScan
 *     <3> column pruning + row filtering: extend PrunedFilteredScan
 *     <4> CatalystScan
 * 1.3 Implement a user-defined RDD.
 * 1.4 Implement a user-defined RDD Partition.
 * 1.5 Implement a user-defined RDD Iterator.
 */
class DefaultSource extends RelationProvider
  with SchemaRelationProvider with CreatableRelationProvider {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    MyRelation(parameters, schema)(sqlContext)
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    createRelation(sqlContext, parameters, data.schema)
  }
}
case class MyRelation(@transient val parameters: Map[String, String], @transient userSchema: StructType)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with PrunedScan with PrunedFilteredScan with Serializable {

  private val logger = LoggerFactory.getLogger(getClass)
  private val sparkContext = sqlContext.sparkContext

  def printStackTraceStr(e: Exception, data: String) = {
    val sw: StringWriter = new StringWriter()
    val pw: PrintWriter = new PrintWriter(sw)
    e.printStackTrace(pw)
    println("======>>printStackTraceStr Exception: " + e.getClass() + "\n==>" + sw.toString() + "\n==>data=" + data)
  }

  // Use the schema supplied by the user (SchemaRelationProvider) if there is one,
  // otherwise fall back to a single integer column named "data".
  override def schema: StructType = {
    if (this.userSchema != null) {
      this.userSchema
    } else {
      StructType(Seq(StructField("data", IntegerType)))
    }
  }
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    logger.info("unhandledFilters with filters " + filters.toList)
    // If unhandled(filter) returns true, the filter is handed back to Spark,
    // which applies it itself; returning false would mean this data source
    // takes responsibility for evaluating the filter.
    def unhandled(filter: Filter): Boolean = {
      filter match {
        case EqualTo(col, v) => {
          println("EqualTo col is :" + col + " value is :" + v)
          true
        }
        case _ => true
      }
    }
    filters.filter(unhandled)
  }
  override def buildScan(): RDD[Row] = {
    logger.info("Table Scan buildScan ")
    new MyRDD[Row](sparkContext)
  }

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    logger.info("pruned build scan for columns " + requiredColumns.toList)
    new MyRDD[Row](sparkContext)
  }

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    logger.info("prunedfilteredScan build scan for columns " + requiredColumns.toList + " with filters " + filters.toList)
    new MyRDD[Row](sparkContext)
  }
}
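In the demo every filter is returned from unhandledFilters as unhandled, so Spark re-applies all predicates itself and the three buildScan variants only log what was pushed down. As a rough sketch of what a real source could do with the pushed-down filters (not part of the demo; the helper name compilePredicate is hypothetical, it handles only two filter types, and it assumes each Row's values follow the relation's schema order), the Filter objects can be compiled into a row-level predicate:

// Sketch only: turn a subset of pushed-down filters into a Row predicate.
// Filters the source does not recognize should still be reported back to
// Spark through unhandledFilters.
def compilePredicate(schema: StructType, filters: Array[Filter]): Row => Boolean = {
  val predicates: Array[Row => Boolean] = filters.collect {
    case EqualTo(attr, value) =>
      val i = schema.fieldIndex(attr)
      (row: Row) => row.get(i) == value
    case IsNotNull(attr) =>
      val i = schema.fieldIndex(attr)
      (row: Row) => !row.isNullAt(i)
  }
  (row: Row) => predicates.forall(p => p(row))
}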
2. Creating the RDD and Partition
package com.spark.datasource.demo
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.Reporter
import org.apache.spark._
import org.apache.spark.rdd._
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import scala.collection.JavaConversions._
case class MyPartition(index: Int) extends Partition
class MyRDD[T: ClassTag](
    @transient private val _sc: SparkContext) extends RDD[T](_sc, Nil) {

  private val logger = LoggerFactory.getLogger(getClass)

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    logger.warn("call MyRDD compute function ")
    val currSplit = split.asInstanceOf[MyPartition]
    new MyIterator(currSplit, context)
  }

  // A single partition; Partition.index must equal its position in the array.
  override protected def getPartitions: Array[Partition] = {
    logger.warn("call MyRDD getPartitions function ")
    val partitions = new Array[Partition](1)
    partitions(0) = MyPartition(0)
    partitions
  }

  // Purely a locality hint; the demo pins everything to localhost.
  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    logger.warn("call MyRDD getPreferredLocations function")
    val currSplit = split.asInstanceOf[MyPartition]
    Seq("localhost")
  }
}
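getPartitions above returns a single partition, and its index must match the partition's position in the returned array, which is why it is 0. If the source were split into several slices, a sketch of the same method (using a hypothetical numPartitions value, e.g. taken from the OPTIONS map) could look like this:

// Sketch only: one MyPartition per slice, with index == position in the array.
override protected def getPartitions: Array[Partition] = {
  val numPartitions = 4 // hypothetical; could be read from the OPTIONS map
  Array.tabulate[Partition](numPartitions)(i => MyPartition(i))
}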
3. Creating the Iterator
package com.spark.datasource.demo
import org.apache.spark._
import org.apache.spark.rdd._
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import java.io._
class MyIterator[T: ClassTag](
    split: MyPartition,
    context: TaskContext) extends Iterator[T] {

  private val logger = LoggerFactory.getLogger(getClass)
  private val currSplit = split
  private var index = 0

  // The demo emits exactly one row per partition.
  override def hasNext: Boolean = {
    if (index == 1) {
      return false
    }
    index = index + 1
    true
  }

  override def next(): T = {
    // A single-column row matching the default schema ("data": Int).
    val r = Row(100000)
    r.asInstanceOf[T]
  }
}
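MyIterator hard-codes a single Row(100000), which matches the default one-column integer schema. A slightly more realistic sketch, assuming the partition's data were already available as an in-memory Seq[Int], simply wraps a standard Scala iterator:

// Sketch only: emit one single-column Row per element of an in-memory collection.
class MySeqIterator(values: Seq[Int]) extends Iterator[Row] {
  private val underlying = values.iterator
  override def hasNext: Boolean = underlying.hasNext
  override def next(): Row = Row(underlying.next())
}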
4. Eclipse Screenshot
5. SBT Directory Structure
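A typical sbt layout for this project looks roughly like the following (splitting the code into three source files mirrors the listings above and is just one possible arrangement):

build.sbt
src/main/scala/com/spark/datasource/demo/DefaultSource.scala
src/main/scala/com/spark/datasource/demo/MyRDD.scala
src/main/scala/com/spark/datasource/demo/MyIterator.scala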
build.sbt:
name := "SparkDataSourceDemo"
version := "0.1"
organization := "com.spark.datasource.demo"
scalaVersion := "2.10.4"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.6.0" % "provided"
resolvers += "Spark Staging Repository" at "https://repository.apache.org/content/repositories/orgapachespark-1038/"
publishMavenStyle := true
publishTo := {
val nexus = "https://oss.sonatype.org/"
if (version.value.endsWith("SNAPSHOT"))
Some("snapshots" at nexus + "content/repositories/snapshots")
else
Some("releases" at nexus + "service/local/staging/deploy/maven2")
}
6. SBT Packaging Command
- Run the following in the same directory as build.sbt:
/usr/local/sbt/sbt package
7. Test Run
1. Run the following in the same directory as build.sbt:
/usr/local/spark/bin/spark-sql --jars target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar
2. Create a table backed by the data source and query it:
CREATE TEMPORARY TABLE test USING com.spark.datasource.demo OPTIONS ();
select * from test;
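The relation can also be loaded without SQL; for example, from spark-shell started with the same --jars option, a short usage sketch with Spark 1.6's DataFrameReader API:

// Sketch only: load the data source through the DataFrame API instead of SQL.
val df = sqlContext.read.format("com.spark.datasource.demo").load()
df.printSchema() // data: integer (the default schema from MyRelation)
df.show()        // a single row containing 100000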