Using a Custom UDAF in Spark

Date: 2022-05-13 20:56:35

1. Introduction to UDAFs

Like Hive, Spark supports custom aggregate functions. A UDAF (User-Defined Aggregate Function) is a user-defined, weakly typed aggregate function.

In memory, every UDAF maintains a buffer for its intermediate state. This buffer is divided into blocks, each addressed by an index starting at 0; a UDAF that aggregates a single value occupies only the block at index 0.

Spark walks through every row of the table and feeds it into the UDAF for aggregation: the value already stored in the buffer is read out, merged with the new input, and written back to the buffer, and the same process repeats for the next row.
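
To make this update flow concrete, the same fold can be written in plain Scala. This is a standalone sketch with made-up city strings; it mirrors the update logic of the UDAF defined in the next section:

// Plain-Scala analogy of the per-row update loop (hypothetical input values):
val rows = Seq("1:Beijing", "2:Shanghai", "1:Beijing")
val merged = rows.foldLeft("") { (buffer, cityInfo) =>
  if (buffer.contains(cityInfo)) buffer      // duplicate: leave the buffer unchanged
  else if (buffer.isEmpty) cityInfo          // first value: seed the buffer
  else buffer + "," + cityInfo               // otherwise append with a comma separator
}
// merged == "1:Beijing,2:Shanghai"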


2. A Custom Distinct-Concatenation UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

class GroupConcatDistinct extends UserDefinedAggregateFunction {

  // Input type: String here; use LongType for a Long input, and so on.
  override def inputSchema: StructType = StructType(StructField("cityInfo", StringType) :: Nil)

  // Buffer schema. For a two-field buffer you would define something like
  // StructType(StructField("bufferCityInfo", StringType) :: StructField("bufferNameInfo", StringType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("bufferCityInfo", StringType) :: Nil)

  // Output type.
  override def dataType: DataType = StringType

  // Whether the function is deterministic, i.e. the same input always produces the same output.
  override def deterministic: Boolean = true

  // Initialize the buffer; block 0 holds our single String field.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The buffer defines a single field, so read it with buffer.getString(0).
    var bufferCityInfo = buffer.getString(0)
    val cityInfo = input.getString(0)
    // Append cityInfo only if the buffer does not already contain it.
    if (!bufferCityInfo.contains(cityInfo)) {
      if ("".equals(bufferCityInfo)) {
        bufferCityInfo += cityInfo
      } else {
        bufferCityInfo += "," + cityInfo
      }
      // Write back to the buffer: block index 0, new value bufferCityInfo.
      buffer.update(0, bufferCityInfo)
    }
  }

  // Merge two partial aggregation buffers (e.g. from different partitions) into one.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // bufferCityInfo1: cityId1:cityName1,cityId2:cityName2
    var bufferCityInfo1 = buffer1.getString(0)
    // bufferCityInfo2: cityId1:cityName1,cityId2:cityName2
    val bufferCityInfo2 = buffer2.getString(0)

    for (cityInfo <- bufferCityInfo2.split(",")) {
      if (!bufferCityInfo1.contains(cityInfo)) {
        if ("".equals(bufferCityInfo1)) {
          bufferCityInfo1 += cityInfo
        } else {
          bufferCityInfo1 += "," + cityInfo
        }
      }
    }

    buffer1.update(0, bufferCityInfo1)
  }

  // Produce the final result from the buffer.
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}
3. Registering and Using the UDAF in Spark

Registration:

(1) Register a plain UDF that joins a Long and a String with a separator (used to build "cityId:cityName" strings):

sparkSession.udf.register("concat_long_string", (v1: Long, v2: String, split: String) => v1 + split + v2)

(2) Register the UDAF itself:

sparkSession.udf.register("group_concat_distinct", new GroupConcatDistinct)

Usage:

val sql = "select area, pid, count(*) click_count, " +
  "group_concat_distinct(concat_long_string(city_id, city_name, ':')) city_infos " +
  "from tmp_area_basic_info group by area, pid"
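
Putting it together, here is a minimal end-to-end sketch. The sample rows, their values, and the local[*] master are assumptions for illustration; only the table and column names from the SQL above are kept:

import org.apache.spark.sql.SparkSession

object GroupConcatDistinctDemo {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("GroupConcatDistinctDemo")
      .master("local[*]") // assumption: run locally for the demo
      .getOrCreate()
    import sparkSession.implicits._

    // Hypothetical sample data matching the columns used in the SQL above.
    Seq(
      ("East", 1L, 101L, "Beijing"),
      ("East", 1L, 102L, "Shanghai"),
      ("East", 1L, 101L, "Beijing"), // duplicate city, deduplicated by the UDAF
      ("West", 2L, 201L, "Chengdu")
    ).toDF("area", "pid", "city_id", "city_name")
      .createOrReplaceTempView("tmp_area_basic_info")

    sparkSession.udf.register("concat_long_string",
      (v1: Long, v2: String, split: String) => v1 + split + v2)
    sparkSession.udf.register("group_concat_distinct", new GroupConcatDistinct)

    sparkSession.sql(
      "select area, pid, count(*) click_count, " +
        "group_concat_distinct(concat_long_string(city_id, city_name, ':')) city_infos " +
        "from tmp_area_basic_info group by area, pid").show(false)
    // Expected shape of a result row: East | 1 | 3 | 101:Beijing,102:Shanghai

    sparkSession.stop()
  }
}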