Using a UDAF to Define a Custom SQL Aggregate Function in Spark (Scala)
1. Development environment
IntelliJ IDEA 2019.1.1 (Ultimate Edition)
JRE: 1.8.0_202-release-1483-b44 x86_64
JVM: OpenJDK 64-Bit Server VM by JetBrains s.r.o
macOS 10.14.4
Project Structure -> Libraries: spark-2.3.3-bin-hadoop2.7
Global Libraries: scala-2.11.11
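If you build with sbt instead of attaching the Spark libraries through IDEA, the equivalent dependencies would look roughly like the sketch below. The exact coordinates are my assumption based on the versions listed above:
scalaVersion := "2.11.11"
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.3.3",
  "org.apache.spark" %% "spark-sql"  % "2.3.3"
)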
2. Requirement
Count how many times each name appears in the data (i.e. count the duplicates), and expose that count as a new field in the result.
3. Implementation
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

class StringCount extends UserDefinedAggregateFunction {
  // Schema of the input column(s)
  override def inputSchema: StructType = StructType(Array(StructField("string", StringType, true)))
  // Schema of the aggregation buffer (the intermediate state)
  override def bufferSchema: StructType = StructType(Array(StructField("Count", IntegerType, true)))
  // Return type of the function.
  // Note: declaring a StructType here triggers an error in 2.x; see section 5.
  override def dataType: DataType = IntegerType
  // Whether the UDAF always produces the same result for the same input; normally true
  override def deterministic: Boolean = true
  // Initialize the buffer for each group
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0
  // Update the buffer when a new input row arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getAs[Int](0) + 1
  // Merge partial aggregates computed on different partitions/nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  // Produce the final aggregated value
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}

object StringCount {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAF_StringCount").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Mock test data
    val testData = Array("leo", "jack", "xuan", "lee", "xuan", "john", "leo", "leo", "lee")
    val testDataRowRDD = spark.sparkContext.parallelize(testData, 2).map(Row(_))
    val testSchema = StructType(Array(StructField("name", StringType, true)))
    val testDF = spark.createDataFrame(testDataRowRDD, testSchema)
    testDF.createOrReplaceTempView("testTableView")

    // Register the StringCount class as a custom SQL function
    spark.udf.register("stringCount", new StringCount)
    val resDF = spark.sql("select name, stringCount(name) as name_Count from testTableView group by name")
    resDF.show()
  }
}
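For reference, the same aggregation can also be expressed through the DataFrame API instead of a SQL string. This is only a sketch, assuming the StringCount class and the testDF defined above:
import org.apache.spark.sql.functions.col

val stringCount = new StringCount
testDF.groupBy("name")
  .agg(stringCount(col("name")).as("name_Count"))  // UDAF applied as a Column expression
  .show()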
4. Output
+----+----------+
|name|name_Count|
+----+----------+
|jack| 1|
|john| 1|
| leo| 3|
| lee| 2|
|xuan| 2|
+----+----------+
5. Error notes
override def dataType: DataType = IntegerType
In Spark 1.5 and earlier, replacing it with the following also runs correctly:
override def dataType: DataType = StructType(Array(StructField("Result",IntegerType,true)))
In the current (2.x) environment, however, that replacement fails at runtime with the following error:
scala.MatchError: 1 (of class java.lang.Integer)
Driver stacktrace:
2019-05-20 15:08:02 INFO DAGScheduler:54 - Job 3 failed: show at UDAF_StringCount.scala:53, took 2.079227 s
Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 42 in stage 7.0 failed 1 times, most recent failure: Lost task 42.0 in stage 7.0 (TID 125, localhost, executor driver): scala.MatchError: 1 (of class java.lang.Integer)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.eval(udaf.scala:444)
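The MatchError occurs because, once dataType is declared as a StructType, Catalyst expects evaluate to return a Row rather than a bare Int. If you do want a struct-typed result in 2.x, wrapping the value in a Row should work; this is a sketch I have not verified against every 2.x release:
override def dataType: DataType = StructType(Array(StructField("Result", IntegerType, true)))
// Return a Row matching the declared StructType, not a bare Int
override def evaluate(buffer: Row): Any = Row(buffer.getAs[Int](0))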
6. A hand-drawn sketch to share
It was drawn by my daughter, Grace.