概述
弱类型UDAF
弱类型UDAF的工作方式非常类似RDD的combineByKeyWithClassTag算子(下面的combineByKey内部就是调用它)。弱类型UDAF需要继承 UserDefinedAggregateFunction 抽象类并实现其抽象方法:
// Excerpt of the RDD combineByKey operator, quoted for comparison with the UDAF callbacks.
// It forwards every argument unchanged to combineByKeyWithClassTag.
def combineByKey[C](
// createCombiner: turns a single value V into an initial combined value C
createCombiner: V => C,
// mergeValue: folds one more value V into an existing combined value C
mergeValue: (C, V) => C,
// mergeCombiners: merges two combined values C into one
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
// mapSideCombine defaults to true, serializer defaults to null
mapSideCombine: Boolean = true,
serializer: Serializer = null): RDD[(K, C)] = self.withScope {
// The trailing (null) fills the second parameter list explicitly
// (presumably the implicit ClassTag[C] -- confirm against the Spark source).
combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners,
partitioner, mapSideCombine, serializer)(null)
}
- override def inputSchema: StructType = 输入schema
- override def bufferSchema: StructType = 聚合过程schema
- override def dataType: DataType = 返回值类型
- override def deterministic: Boolean = 函数是否具有确定性(相同输入是否总是得到相同输出)
- override def initialize(buffer: MutableAggregationBuffer): Unit = 初始函数用来初始化基准值
- override def update(buffer: MutableAggregationBuffer, input: Row): Unit = 分区内如何聚合
- override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 分区之间如何聚合
- override def evaluate(buffer: Row): Any = 聚合结果计算
一个自定义求平均数UDAF例子
/**
 * Demo entry point for the weakly typed UDAF: loads the employees JSON file,
 * registers `MyAvgUDAF` under the SQL name `cusAvg`, and runs the aggregate
 * once through SQL and once through the DataFrame API.
 *
 * An explicit `main` is used instead of `extends App` because `App` defers the
 * object body via `DelayedInit`, which is a known source of initialization
 * problems for Spark applications.
 */
object UDAFApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("UDAP-App").getOrCreate()
    import spark.implicits._

    // try/finally guarantees the session is closed even when loading or a query fails.
    try {
      val df = spark.read.format("json").load("D:\\data\\employees.json")

      // Register the UDAF under a SQL-callable name.
      // NOTE(review): the original note says only UserDefinedAggregateFunction can be
      // registered for SQL; that matches Spark 2.x -- confirm for the Spark version in use.
      spark.udf.register("cusAvg", MyAvgUDAF)

      // Expose the DataFrame as a temporary view so it can be queried with SQL.
      df.createTempView("employees_view")
      spark.sql("select cusAvg(salary) as salary from employees_view").show()

      // Same aggregation through the DataFrame API.
      df.select(MyAvgUDAF.apply($"salary")).show()
    } finally {
      spark.close()
    }
  }
}
/**
 * Weakly typed (untyped) UDAF that computes the arithmetic mean of a Double column.
 *
 * The aggregation buffer holds two slots: the running sum (index 0, Double) and
 * the number of non-null inputs seen so far (index 1, Long). The three phases
 * mirror the createCombiner / mergeValue / mergeCombiners functions of the RDD
 * combineByKeyWithClassTag operator.
 */
object MyAvgUDAF extends UserDefinedAggregateFunction {

  /** Schema of the input arguments: a single Double column. */
  override def inputSchema: StructType = StructType(StructField("input", DoubleType) :: Nil)

  /** Schema of the intermediate aggregation buffer: running sum and row count. */
  override def bufferSchema: StructType =
    StructType(StructField("Sum", DoubleType) :: StructField("Count", LongType) :: Nil)

  /** Type of the final result. */
  override def dataType: DataType = DoubleType

  /** Whether the function always produces the same output for the same input. */
  override def deterministic: Boolean = true

  /** Sets the buffer to its starting state (the createCombiner counterpart). */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D // sum starts at 0
    buffer(1) = 0L // count starts at 0
  }

  /** Folds one input row into the buffer (the mergeValue counterpart). */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Rows whose first column is null contribute to neither the sum nor the count.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  /** Combines two partial buffers into `buffer1` (the mergeCombiners counterpart). */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Produces the final average from the buffer.
   *
   * Returns null when no non-null input was aggregated; dividing by a zero
   * count would otherwise yield NaN rather than a SQL NULL.
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getDouble(0) / count
  }
}
强类型UDAF
自定义强类型UDAF需要继承抽象类 Aggregator 并实现其抽象方法。
注意:
这种定义方式在 Spark 3.0 之前不能通过 spark.udf.register 注册,也不能直接用在SQL中(Spark 3.0 起可以先用 functions.udaf 包装后再注册)
一个强类型UDAF定义如下
/**
 * Demo entry point for the strongly typed UDAF: loads the employees JSON file
 * as a `Dataset[Employee]` and applies `MyAverage` through the Dataset API.
 *
 * An explicit `main` is used instead of `extends App` because `App` defers the
 * object body via `DelayedInit`, which is a known source of initialization
 * problems for Spark applications.
 */
object UDAFApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("UDAP-App").getOrCreate()
    import spark.implicits._

    // try/finally guarantees the session is closed even when loading or the query fails.
    try {
      val ds = spark.read.format("json").load("D:\\data\\employees.json").as[Employee]

      // Dataset API: turn the aggregator into a typed column named "salary".
      ds.select(MyAverage.toColumn.name("salary")).show()
    } finally {
      spark.close()
    }
  }
}
/**
 * Input record type of the typed aggregation: one employee row.
 *
 * Case-class parameters are already public vals, so no explicit `val` is needed.
 */
case class Employee(name: String, salary: Long)
/**
 * Intermediate buffer type of the typed aggregation.
 *
 * Both fields are declared `var` so the buffer can be updated in place.
 */
case class Average(
  var sum: Long,  // running total
  var count: Long // number of values added so far
)
/**
 * Strongly typed aggregator that averages `Employee.salary`.
 *
 * Input type is `Employee`, the intermediate buffer is `Average` (mutated in
 * place), and the result is a `Double`.
 */
object MyAverage extends Aggregator[Employee, Average, Double] {

  /** Starting buffer: sum 0 and count 0. */
  override def zero: Average = Average(0, 0)

  /** Adds one employee's salary to the buffer and returns the same buffer instance. */
  override def reduce(b: Average, a: Employee): Average = {
    b.sum = b.sum + a.salary
    b.count = b.count + 1
    b
  }

  /** Folds the second partial buffer into the first and returns the first. */
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  /**
   * Converts the final buffer into the average.
   *
   * The division is done in floating point, so a zero count yields NaN
   * instead of throwing.
   */
  override def finish(reduction: Average): Double = {
    val total = reduction.sum
    val rows = reduction.count
    // Debug output of the final sum and count, separated by a single space.
    println(total + " "+ rows)
    total.toDouble / rows
  }

  /** Encoder for the intermediate `Average` buffer. */
  override def bufferEncoder: Encoder[Average] = Encoders.product

  /** Encoder for the `Double` result. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}