Spark SQL ships with a rich set of built-in functions, but real business logic is often complex enough that the built-ins cannot cover it, so Spark SQL also lets you register your own functions.
UDF: an ordinary (scalar) function that takes one or more arguments and returns a single value per row, e.g. length(), isnull().
UDAF: an aggregate function that takes a group of values and returns one aggregated result, e.g. max(), avg(), sum().
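As a quick preview of how the two kinds are called from SQL, the sketch below uses strLength and avg_format, the custom functions registered later in this article, against the user temporary views built in those examples (sparkSession is the SparkSession created there):
// Scalar UDF: evaluated once per row, one output value per input row.
sparkSession.sql("select id, name, strLength(name) as length from user").show();
// UDAF: evaluated once per group, collapses many rows into a single value.
sparkSession.sql("select name, avg_format(achieve) as avg_achieve from user group by name").show();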
Writing a UDF in Spark
The first example uses the older SQLContext entry point (the pre-Spark 2.0 style; its constructor is deprecated in 2.x) and shows a UDF with a single input parameter and a single output value.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        JavaRDD<String> javaRDD = javaSparkContext.parallelize(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Pick UDF1, UDF2, ... (up to UDF22) according to the number of input parameters.
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();
        javaSparkContext.stop();
    }
}
Output:
+---+--------+------+
| id| name|length|
+---+--------+------+
| 1|zhangsan| 8|
| 2| lisi| 4|
| 3| wangwu| 6|
| 4| zhaoliu| 7|
+---+--------+------+
The example above showed a UDF with a single input and a single output. The next example uses the Spark 2.0 SparkSession API and additionally implements a UDF with three inputs and one output.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Pick UDF1, UDF2, ... (up to UDF22) according to the number of input parameters.
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();
        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}
Person.java
package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable {
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public static Person parsePerson(String line) {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}
Note that a UDF needs to be registered only once per SparkSession; after that it can be used in any number of queries (here strLength and strConcat are used by both showByStruct and showBySchema).
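If you want to confirm that a function is registered on the current session, the catalog can be queried; a minimal sketch (functionExists is part of the public Catalog API, but this check is not in the original example):
// Prints true once strLength has been registered on this SparkSession.
System.out.println(sparkSession.catalog().functionExists("strLength"));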
Output:
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1-zhangsan|
| 2| lisi| 4| 2-lisi|
| 3| wangwu| 6| 3-wangwu|
| 4| zhaoliu| 7| 4-zhaoliu|
+---+--------+------+----------+
==========================================
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1+zhangsan|
| 2| lisi| 4| 2+lisi|
| 3| wangwu| 6| 3+wangwu|
| 4| zhaoliu| 7| 4+zhaoliu|
+---+--------+------+----------+
The two examples above should be enough to cover the basic UDF usage pattern.
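As an aside, on Java 8 and later the anonymous UDF1/UDF3 implementations can also be written as lambdas, since the UDFn interfaces are functional interfaces. A minimal sketch, assuming the same sparkSession and user view as in TestUDF2 above:
// Equivalent registrations using lambdas (sketch; behavior matches the anonymous classes above).
sparkSession.udf().register("strLength",
        (UDF1<String, Integer>) s -> s.length(),
        DataTypes.IntegerType);
sparkSession.udf().register("strConcat",
        (UDF3<String, String, String, String>) (sep, a, b) -> a + sep + b,
        DataTypes.StringType);
sparkSession.sql("select id,name,strConcat('-',id,name) as str from user").show();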
Writing a UDAF in Spark
A custom aggregate function extends UserDefinedAggregateFunction. The abstract class looks like this (the listing below is taken from an older Spark release, so internals such as AggregateExpression2 differ between versions, but the abstract methods you must implement are the same):
package org.apache.spark.sql.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   * new StructType()
   *   .add("doubleInput", DoubleType)
   *   .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
  // Schema describing the data types of the input arguments.
  def inputSchema: StructType

  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   * new StructType()
   *   .add("doubleInput", DoubleType)
   *   .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
  // Schema describing the data types of the intermediate aggregation buffer.
  def bufferSchema: StructType

  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
  // Data type of the final aggregation result.
  def dataType: DataType

  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
  // Determinism flag: if true, the same input always produces the same output.
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  // Sets the initial ("zero") value of the aggregation buffer. The contract: merging two initial
  // buffers must again yield the initial buffer. For example, if merge adds the values and the
  // initial value were 1, merging two initial buffers would give 2 and break the contract;
  // this is why the initial value is also called the "zero value".
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
  // Updates the buffer with one input row, similar to the combine step of combineByKey.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
  // Merges buffer2 into buffer1; used when combining partial aggregates from different
  // partitions, similar to reduceByKey. Note there is no return value: the merged result
  // must be written back into buffer1.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
  // Computes and returns the final aggregation result.
  def evaluate(buffer: Row): Any

  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
  // Aggregates over all input values.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
  // Aggregates over the distinct input values.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * :: Experimental ::
 * A [[Row]] representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}
An example UDAF that averages a single column (formatted to two decimal places):
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType().add("myinput", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero") value of the buffer. The contract: merging two initial buffers
    // must again yield the initial buffer, which is why this is also called the "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: 0L is a Long zero
        buffer.update(1, 0d); // mysum: 0d is a Double zero
    }

    /**
     * Per-partition combine: updates the buffer with one input row, similar to combineByKey.
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1);                    // row count + 1
        buffer.update(1, buffer.getDouble(1) + input.getDouble(0)); // running sum
    }

    /**
     * Cross-partition merge (MutableAggregationBuffer extends Row): folds buffer2 into buffer1,
     * similar to reduceByKey. There is no return value, so the result must go back into buffer1.
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // average = sum / count, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
The following shows how to use the custom UDAF:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80",
                "2,zhangsan,History,87",
                "3,zhangsan,Chinese,88",
                "4,zhangsan,Chemistry,96",
                "5,lisi,English,70",
                "6,lisi,Chinese,74",
                "7,lisi,History,75",
                "8,lisi,Chemistry,77",
                "9,lisi,Physics,79",
                "10,lisi,Biology,82",
                "11,wangwu,English,96",
                "12,wangwu,Chinese,98",
                "13,wangwu,History,91",
                "14,zhaoliu,English,68",
                "15,zhaoliu,Chinese,66"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve = Double.parseDouble(fields[3]);
                return RowFactory.create(id, name, subject, achieve);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new SimpleAvg();
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
        rows2.show();
    }
}
Output:
+---+--------+---------+-------+
| id| name| subject|achieve|
+---+--------+---------+-------+
| 1|zhangsan| English| 80.0|
| 2|zhangsan| History| 87.0|
| 3|zhangsan| Chinese| 88.0|
| 4|zhangsan|Chemistry| 96.0|
| 5| lisi| English| 70.0|
| 6| lisi| Chinese| 74.0|
| 7| lisi| History| 75.0|
| 8| lisi|Chemistry| 77.0|
| 9| lisi| Physics| 79.0|
| 10| lisi| Biology| 82.0|
| 11| wangwu| English| 96.0|
| 12| wangwu| Chinese| 98.0|
| 13| wangwu| History| 91.0|
| 14| zhaoliu| English| 68.0|
| 15| zhaoliu| Chinese| 66.0|
+---+--------+---------+-------+
+--------+-----------------+
| name| avg_achieve|
+--------+-----------------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi|76.16666666666667|
+--------+-----------------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi| 76.17|
+--------+-----------+
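For this particular case the built-in functions would give the same two-decimal result, so the custom UDAF is mainly a template for aggregations that the built-ins do not cover. A sketch against the same user view (round is a standard Spark SQL function):
// Same result as avg_format here: per-group average, rounded to two decimals.
sparkSession.sql("select name, round(avg(achieve), 2) as avg_achieve from user group by name").show();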
A UDAF that sums multiple columns within each row and then averages those per-row sums:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilAvg(2);
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
        rows2.show();
    }
}
The code above builds a Dataset with the columns id, name, subject, achieve1, and achieve2. MutilAvg sums the requested columns within each row and then averages those per-row sums over the group.
MutilAvg.java (the UDAF implementation):
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutilAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;

    public MutilAvg(int columnSize) {
        this.columnSize = columnSize;
    }

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType();
        for (int i = 0; i < columnSize; i++) {
            // StructType is immutable: add() returns a new instance, so it must be reassigned.
            structType = structType.add("myinput" + i, DataTypes.DoubleType);
        }
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero") value of the buffer. The contract: merging two initial buffers
    // must again yield the initial buffer, which is why this is also called the "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: 0L is a Long zero
        buffer.update(1, 0d); // mysum: 0d is a Double zero
    }

    /**
     * Per-partition combine: updates the buffer with one input row, similar to combineByKey.
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1); // row count + 1

        // One input row carries several columns, so sum the columns of this row first.
        Double currentLineSumValue = 0d;
        for (int i = 0; i < columnSize; i++) {
            currentLineSumValue += input.getDouble(i);
        }
        buffer.update(1, buffer.getDouble(1) + currentLineSumValue); // running sum
    }

    /**
     * Cross-partition merge (MutableAggregationBuffer extends Row): folds buffer2 into buffer1,
     * similar to reduceByKey. There is no return value, so the result must go back into buffer1.
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // average = sum / count, formatted to two decimal places
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
Output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
A UDAF that tracks the maximum of each column separately and then returns the largest of those per-column maxima:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilMax(2, 0);
        sparkSession.udf().register("max_vals", udaf);

        Dataset<Row> rows1 = sparkSession.sql(""
                + "select name,max(achieve) as max_achieve "
                + "from "
                + "("
                + "select name,max(achieve1) achieve from user group by name "
                + "union all "
                + "select name,max(achieve2) achieve from user group by name "
                + ") t10 "
                + "group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
        rows2.show();
    }
}
The code above builds the same Dataset with the columns id, name, subject, achieve1, and achieve2. MutilMax computes the maximum of each input column separately and then returns the largest of those per-column maxima.
MutilMax.java (the UDAF implementation):
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MutilMax extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;
    private Double defaultValue;

    public MutilMax(int columnSize, double defaultValue) {
        this.columnSize = columnSize;
        this.defaultValue = defaultValue;
    }

    @Override
    public StructType inputSchema() {
        List<StructField> inputFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
        }
        StructType inputSchema = DataTypes.createStructType(inputFields);
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> bufferFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
        }
        StructType bufferSchema = DataTypes.createStructType(bufferFields);
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    // Sets the initial ("zero") value of the buffer. The contract: merging two initial buffers
    // must again yield the initial buffer, which is why this is also called the "zero value".
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        for (int i = 0; i < this.columnSize; i++) {
            buffer.update(i, 0d);
        }
    }

    /**
     * Per-partition combine: updates the buffer with one input row, similar to combineByKey.
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        for (int i = 0; i < this.columnSize; i++) {
            // Keep the larger of the buffered maximum and the current input value for column i.
            if (buffer.getDouble(i) > input.getDouble(i)) {
                buffer.update(i, buffer.getDouble(i));
            } else {
                buffer.update(i, input.getDouble(i));
            }
        }
    }

    /**
     * Cross-partition merge (MutableAggregationBuffer extends Row): folds buffer2 into buffer1,
     * similar to reduceByKey. There is no return value, so the result must go back into buffer1.
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer1.getDouble(i) > buffer2.getDouble(i)) {
                buffer1.update(i, buffer1.getDouble(i));
            } else {
                buffer1.update(i, buffer2.getDouble(i));
            }
        }
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // Pick the largest of the per-column maxima; fall back to defaultValue if none was found.
        Double max = Double.MIN_VALUE;
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > max) {
                max = buffer.getDouble(i);
            }
        }
        if (max == Double.MIN_VALUE) {
            max = this.defaultValue;
        }
        return max;
    }
}
Output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
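Here too the built-ins can express the same logic directly, so MutilMax is mainly a template for custom multi-column aggregations. A sketch against the same user view (greatest is a standard Spark SQL function that returns the per-row maximum of its arguments):
// Same result as max_vals here: per-row maximum across the two columns, then the group maximum.
sparkSession.sql("select name, max(greatest(achieve1, achieve2)) as max_achieve from user group by name").show();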
Writing an Aggregator in Spark
Implementing an avg function with a typed Aggregator:
Step 1: define an Average class that stores the running sum and count;
import java.io.Serializable;

public class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }

    public Average() { }

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }
}
Step 2: define an Employee class that stores an employee's name and salary;
import java.io.Serializable;

public class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getSalary() {
        return salary;
    }

    public void setSalary(long salary) {
        this.salary = salary;
    }

    public Employee() { }

    public Employee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
}
Step 3: define the Aggregator that averages employee salaries;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage extends Aggregator<Employee, Average, Double> {

    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
Step 4: call the Aggregator from Spark and verify the result.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;

public class SparkClient {
    public static void main(String[] args) {
        final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<Employee> employeeList = new ArrayList<Employee>();
        employeeList.add(new Employee("Michael", 3000L));
        employeeList.add(new Employee("Andy", 4500L));
        employeeList.add(new Employee("Justin", 3500L));
        employeeList.add(new Employee("Berta", 4000L));

        JavaRDD<Employee> rows = ctx.parallelize(employeeList);
        Dataset<Employee> ds = spark.createDataFrame(rows, Employee.class).map(new MapFunction<Row, Employee>() {
            @Override
            public Employee call(Row row) throws Exception {
                return new Employee(row.getString(0), row.getLong(1));
            }
        }, Encoders.bean(Employee.class));
        ds.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // +-------+------+

        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // |        3750.0|
        // +--------------+
    }
}
Output:
+-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
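The typed Aggregator can also be applied per group with groupByKey. A minimal sketch, reusing ds, averageSalary, and the Employee bean from the example above (grouping by name is only for illustration, since every name occurs once in this data; it additionally requires imports of org.apache.spark.sql.KeyValueGroupedDataset and scala.Tuple2):
// Group the typed Dataset<Employee> by name and run the Aggregator inside each group.
KeyValueGroupedDataset<String, Employee> byName =
        ds.groupByKey((MapFunction<Employee, String>) Employee::getName, Encoders.STRING());
Dataset<Tuple2<String, Double>> avgByName = byName.agg(averageSalary);
avgByName.show();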