Flink UDF
概述
- 什么是UDF
- UDF是User-defined Functions的縮寫,即自定義函數(shù)。
- UDF種類
- UDF分為三種:Scalar Functions、Table Functions、Aggregation Functions
- Scalar Functions
- Table Functions
- 和上面的Scalar Functions接收的參數(shù)個數(shù)一樣,不同的是可以返回多行,而不是單個值
- Aggregation Functions
- 從名字就可以看出來,這個是搭配GROUP BY一起使用的,將表的一個或多個列的一行或多行數(shù)據(jù)匯聚到一個值里面,看上去有點(diǎn)拗口,其實(shí)可以把它簡單理解為SQL中的聚合函數(shù)
- Table Aggregation Functions
- 相當(dāng)于Table Functions和Aggregation Functions的結(jié)合體,聚合之后,再返回多行多列
- 為什么要有UDF
- Flink SQL目前提供了很多的內(nèi)置UDF,主要是為了大家更方便的編寫SQL代碼完成自己的業(yè)務(wù)邏輯,具體內(nèi)置的UDF可以參考官方文檔;同時,F(xiàn)link 也支持注冊自己的UDF,下面正式開始我們今天的UDF探索之旅。
Scalar Functions
//不墨跡,我們直接貼代碼
package udf;
import org.apache.flink.table.functions.ScalarFunction;
public class TestScalarFunc extends ScalarFunction {
private int factor = 2020;
//和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯,參數(shù)個數(shù)任意
public int eval() {
return factor;
}
public int eval(int a) {
return a * factor;
}
public int eval(int... a) {
int res = 1;
for (int i : a) {
res *= i;
}
return res * factor;
}
}
- 自定義Scalar Functions,需要繼承
ScalarFunction,并且有一個public的eval(),方法可以接受任意個數(shù)參數(shù),同時也可以在一個類中重載eval()
- 寫完UDF之后需要注冊到我們的運(yùn)行環(huán)境中,使用姿勢有兩種:
tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
tEnv.registerFunction("test",new TestScalarFunc());
- 第一種偏向在純SQL的環(huán)境中使用,比如我們有個Flink SQL的提交平臺,只支持純SQL語句,那我們可以把自己寫的UDF打包上傳到平臺后,通過SQL語句
CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'來創(chuàng)建UDF;同時可以把UDF注冊到catalog中,這里先不深入討論,之后我們說到Flink X Hive的時候再聊吧
- 第二種注冊方式,如果我們的類有構(gòu)造方法,可以通過new 對象的時候傳遞變量進(jìn)去,更為靈活一點(diǎn)
Table Functions
package udf;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.calcite.shaded.com.google.common.base.Strings;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;
public class TestTableFunction extends TableFunction {
private String separator = ",";
public TestTableFunction(String separator) {
this.separator = separator;
}
//和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯,參數(shù)個數(shù)任意
public void eval(String input){
Row row = null;
if (Strings.isNullOrEmpty(input)){
row = new Row(2);
row.setField(0,null);
row.setField(1,0);
collect(row);
}else {
String[] split = input.split(separator);
for (String word : split) {
row = new Row(2);
row.setField(0,word);
row.setField(1,word.length());
collect(row);
}
}
}
@Override
public TypeInformation getResultType() {
return Types.ROW(Types.STRING,Types.INT);
}
}
- 自定義Table Functions,需要繼承
TableFunction,并且有一個public的eval(),方法可以接受任意個數(shù)參數(shù),同時也可以在一個類中重載eval()
- 因?yàn)榉祷氐氖?code>Row類型,所以需要重寫
getResultType()
- 在SQL語句中使用時,有兩種寫法:
select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)
select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE
- 第一種的用法相當(dāng)于用的是
CROSS JOIN
- 第二種的用法是
LEFT JOIN
Aggregation Functions
package udf;
import org.apache.flink.table.functions.AggregateFunction;
import java.util.Iterator;
public class TestAggregateFunction extends AggregateFunction<Long, TestAggregateFunction.SumAll> {
//返回最終結(jié)果
@Override
public Long getValue(SumAll acc) {
return acc.sum;
}
//構(gòu)建保存中間結(jié)果的對象
@Override
public SumAll createAccumulator() {
return new SumAll();
}
//和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯
public void accumulate(SumAll acc, long iValue) {
acc.sum += iValue;
}
//減去要撤回的值
public void retract(SumAll acc, long iValue) {
acc.sum -= iValue;
}
//從每個分區(qū)把數(shù)據(jù)取出來然后合并
public void merge(SumAll acc, Iterable<SumAll> it) {
Iterator<SumAll> iter = it.iterator();
while (iter.hasNext()) {
SumAll a = iter.next();
acc.sum += a.sum;
}
}
//重置內(nèi)存中值時調(diào)用
public void resetAccumulator(SumAll acc) {
acc.sum = 0L;
}
public static class SumAll {
public long sum = 0;
}
}
- 自定義Aggregation Functions,需要繼承
AggregateFunction,并且必須要有 以下的方法
-
createAccumulator() 創(chuàng)建一個保留中間結(jié)果的數(shù)據(jù)結(jié)構(gòu)
-
accumulate() 把每個輸入行與中間結(jié)果進(jìn)行計(jì)算,可以重載
-
getValue() 獲取最終結(jié)果
- 根據(jù)不同的使用情況,還需要以下的方法
-
retract() 用于bounded OVER窗口,即窗口有結(jié)束時間
-
merge()用于多次批量聚合和會話窗口合并
-
resetAccumulator()用于多次批量聚合時,清空中間結(jié)果
Table Aggregation Functions
package udf;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
public class TestTableAggregateFunction extends TableAggregateFunction<Row,TestTableAggregateFunction.Top2> {
//創(chuàng)建保留中間結(jié)果的對象
@Override
public Top2 createAccumulator() {
Top2 t = new Top2();
t.f1 = Integer.MIN_VALUE;
t.f2 = Integer.MIN_VALUE;
return t;
}
//與傳入值進(jìn)行計(jì)算的方法
public void accumulate(Top2 t, Integer v) {
//如果傳入的值比內(nèi)存中第一個值大,那就用第一個值替換第二個值,傳入的值替換第一個值;
//如果傳入的值比第二個值大比第一個小,那么就替換第二個值。
if (v > t.f1) {
t.f2 = t.f1;
t.f1 = v;
} else if (v > t.f2) {
t.f2 = v;
}
}
//合并分區(qū)的值
public void merge(Top2 t, Iterable<Top2> iterable) {
for (Top2 otherT : iterable) {
accumulate(t, otherT.f1);
accumulate(t, otherT.f2);
}
}
//拿到返回結(jié)果的方法
public void emitValue(Top2 t, Collector<Row> out) {
Row row = null;
//發(fā)射數(shù)據(jù)
//如果第一個值不是最小的int值,那就發(fā)出去
//如果第二個值不是最小的int值,那就發(fā)出去
if (t.f1 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.f1);
row.setField(1,1);
out.collect(row);
}
if (t.f2 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.f2);
row.setField(1,2);
out.collect(row);
}
}
//撤回流拿結(jié)果的方法,會發(fā)射撤回?cái)?shù)據(jù)
public void emitUpdateWithRetract(Top2 t, RetractableCollector<Row> out) {
Row row = null;
//如果新舊值不相等,才需要撤回,不然沒必要
//如果舊值不等于int最小值,說明之前發(fā)射過數(shù)據(jù),需要撤回
//然后將新值發(fā)射出去
if (!t.f1.equals(t.oldF1)) {
if (t.oldF1 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.oldF1);
row.setField(1,1);
out.retract(row);
}
row = new Row(2);
row.setField(0,t.f1);
row.setField(1,1);
out.collect(row);
t.oldF1 = t.f1;
}
//和上面邏輯一樣,只是一個發(fā)射f1,一個f2
if (!t.f2.equals(t.oldF2)) {
// if there is an update, retract old value then emit new value.
if (t.oldF2 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.oldF2);
row.setField(1,2);
out.retract(row);
}
row = new Row(2);
row.setField(0,t.f2);
row.setField(1,2);
out.collect(row);
t.oldF2 = t.f2;
}
}
//保留中間結(jié)果的類
public class Top2{
public Integer f1;
public Integer f2;
public Integer oldF1;
public Integer oldF2;
}
@Override
public TypeInformation<Row> getResultType() {
return Types.ROW(Types.INT,Types.INT);
}
}
- 自定義Table Aggregation Functions,需要繼承
TableAggregateFunction,并且必須要有 以下的方法
-
createAccumulator() 創(chuàng)建一個保留中間結(jié)果的數(shù)據(jù)結(jié)構(gòu)
-
accumulate() 把每個輸入行與中間結(jié)果進(jìn)行計(jì)算,可以重載
- 根據(jù)不同的使用情況,還需要以下的方法
-
retract() 用于bounded OVER窗口,即窗口有結(jié)束時間
-
merge()用于多次批量聚合和會話窗口合并
-
resetAccumulator()用于多次批量聚合時,清空中間結(jié)果
-
emitValue() 用于批量和窗口聚合拿到結(jié)果
-
emitUpdateWithRetract() 用于流式計(jì)算的撤回流
- 目前Table Aggregation Functions只支持在Table Api中使用
完整代碼
//下面貼出來的是主類的代碼,具體每個UDF的類上面已經(jīng)有了
package FlinkSql;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import udf.TestAggregateFunction;
import udf.TestScalarFunc;
import udf.TestTableAggregateFunction;
import udf.TestTableFunction;
import static util.FlinkConstant.env;
import static util.FlinkConstant.tEnv;
public class FlinkSql04 {
public static void main(String[] args) throws Exception {
DataStream<Row> source = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row = new Row(3);
row.setField(0, 2);
row.setField(1, 3);
row.setField(2, 3);
ctx.collect(row);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.INT,Types.INT,Types.INT));
tEnv.createTemporaryView("t",source,"a,b,c");
// tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
tEnv.registerFunction("test",new TestScalarFunc());
Table table = tEnv.sqlQuery("select test() as a,test(a) as b, test(a,b,c) as c from t");
DataStream<Row> res = tEnv.toAppendStream(table, Row.class);
// res.print().name("Scalar Functions Print").setParallelism(1);
DataStream<Row> ds2 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row = new Row(2);
row.setField(0, 22);
row.setField(1, "aa,b,cdd,dfsfdg,exxxxx");
ctx.collect(row);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.INT, Types.STRING));
tEnv.createTemporaryView("t2",ds2,"age,name_list");
tEnv.registerFunction("test2",new TestTableFunction(","));
// Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)");
Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE");
DataStream<Row> res2 = tEnv.toAppendStream(table2, Row.class);
// res2.print().name("Table Functions Print").setParallelism(1);
DataStream<Row> ds3 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row1 = new Row(2);
row1.setField(0,"a");
row1.setField(1,1L);
Row row2 = new Row(2);
row2.setField(0,"a");
row2.setField(1,2L);
Row row3 = new Row(2);
row3.setField(0,"b");
row3.setField(1,100L);
ctx.collect(row1);
ctx.collect(row2);
ctx.collect(row3);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.STRING, Types.LONG));
tEnv.createTemporaryView("t3",ds3,"name,cnt");
tEnv.registerFunction("test3",new TestAggregateFunction());
Table table3 = tEnv.sqlQuery("select name,test3(cnt) as mySum from t3 group by name");
DataStream<Tuple2<Boolean, Row>> res3 = tEnv.toRetractStream(table3, Row.class);
// res3.print().name("Aggregate Functions Print").setParallelism(1);
DataStream<Row> ds4 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row1 = new Row(2);
row1.setField(0,"a");
row1.setField(1,1);
Row row2 = new Row(2);
row2.setField(0,"a");
row2.setField(1,2);
Row row3 = new Row(2);
row3.setField(0,"a");
row3.setField(1,100);
ctx.collect(row1);
ctx.collect(row2);
ctx.collect(row3);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.STRING, Types.INT));
tEnv.createTemporaryView("t4",ds4,"name,cnt");
tEnv.registerFunction("test4",new TestTableAggregateFunction());
Table table4 = tEnv.sqlQuery("select * from t4");
Table table5 = table4.groupBy("name")
.flatAggregate("test4(cnt) as (v,rank)")
.select("name,v,rank");
DataStream<Tuple2<Boolean, Row>> res4 = tEnv.toRetractStream(table5, Row.class);
res4.print().name("Aggregate Functions Print").setParallelism(1);
env.execute("test udf");
}
}