Flink Sql教程(4)

Flink UDF

概述

  • 什么是UDF
    • UDF是User-defined Functions的縮寫,即自定義函數(shù)。
  • UDF種類
    • UDF分為三種:Scalar Functions、Table Functions、Aggregation Functions
    • Scalar Functions
      • 接收0、1、多個參數(shù),返回一個值
    • Table Functions
      • 和上面的Scalar Functions接收的參數(shù)個數(shù)一樣,不同的是可以返回多行,而不是單個值
    • Aggregation Functions
      • 從名字就可以看出來,這個是搭配GROUP BY一起使用的,將表的一個或多個列的一行或多行數(shù)據(jù)匯聚到一個值里面,看上去有點(diǎn)拗口,其實(shí)可以把它簡單理解為SQL中的聚合函數(shù)
    • Table Aggregation Functions
      • 相當(dāng)于Table Functions和Aggregation Functions的結(jié)合體,聚合之后,再返回多行多列
  • 為什么要有UDF
    • Flink SQL目前提供了很多的內(nèi)置UDF,主要是為了大家更方便的編寫SQL代碼完成自己的業(yè)務(wù)邏輯,具體內(nèi)置的UDF可以參考官方文檔;同時,F(xiàn)link 也支持注冊自己的UDF,下面正式開始我們今天的UDF探索之旅。

Scalar Functions

    //不墨跡,我們直接貼代碼
    package udf;

    import org.apache.flink.table.functions.ScalarFunction;
    
    
    public class TestScalarFunc extends ScalarFunction {
    
        private int factor = 2020;
        //和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯,參數(shù)個數(shù)任意
        public int eval() {
            return factor;
        }
    
        public int eval(int a) {
            return a * factor;
        }
    
        public int eval(int... a) {
            int res = 1;
            for (int i : a) {
                res *= i;
            }
            return res * factor;
        }
}

  • 自定義Scalar Functions,需要繼承ScalarFunction,并且有一個publiceval(),方法可以接受任意個數(shù)參數(shù),同時也可以在一個類中重載eval()
  • 寫完UDF之后需要注冊到我們的運(yùn)行環(huán)境中,使用姿勢有兩種:
    • tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
    • tEnv.registerFunction("test",new TestScalarFunc());
    • 第一種偏向在純SQL的環(huán)境中使用,比如我們有個Flink SQL的提交平臺,只支持純SQL語句,那我們可以把自己寫的UDF打包上傳到平臺后,通過SQL語句CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'來創(chuàng)建UDF;同時可以把UDF注冊到catalog中,這里先不深入討論,之后我們說到Flink X Hive的時候再聊吧
    • 第二種注冊方式,如果我們的類有構(gòu)造方法,可以通過new 對象的時候傳遞變量進(jìn)去,更為靈活一點(diǎn)

Table Functions

    package udf;

    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.calcite.shaded.com.google.common.base.Strings;
    import org.apache.flink.table.functions.TableFunction;
    import org.apache.flink.types.Row;
    
    public class TestTableFunction extends TableFunction {
    
        private String separator = ",";
    
        public TestTableFunction(String separator) {
            this.separator = separator;
        }

        //和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯,參數(shù)個數(shù)任意
        public void eval(String input){
            
            Row row = null;
            
            if (Strings.isNullOrEmpty(input)){
                
                row = new Row(2);
                row.setField(0,null);
                row.setField(1,0);
                collect(row);
                
            }else {
                
                String[] split = input.split(separator);
                
                for (String word : split) {
                    row = new Row(2);
                    row.setField(0,word);
                    row.setField(1,word.length());
                    collect(row);
                }
                
            }
    
        }
    
        @Override
        public TypeInformation getResultType() {
            return Types.ROW(Types.STRING,Types.INT);
        }
    }

  • 自定義Table Functions,需要繼承TableFunction,并且有一個publiceval(),方法可以接受任意個數(shù)參數(shù),同時也可以在一個類中重載eval()
  • 因?yàn)榉祷氐氖?code>Row類型,所以需要重寫getResultType()
  • 在SQL語句中使用時,有兩種寫法:
    • select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)
    • select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE
    • 第一種的用法相當(dāng)于用的是CROSS JOIN
    • 第二種的用法是LEFT JOIN

Aggregation Functions

    package udf;

    import org.apache.flink.table.functions.AggregateFunction;
    
    import java.util.Iterator;
    
    public class TestAggregateFunction extends AggregateFunction<Long, TestAggregateFunction.SumAll> {
        //返回最終結(jié)果
        @Override
        public Long getValue(SumAll acc) {
            return acc.sum;
        }
        //構(gòu)建保存中間結(jié)果的對象
        @Override
        public SumAll createAccumulator() {
            return new SumAll();
        }
        //和傳入數(shù)據(jù)進(jìn)行計(jì)算的邏輯
        public void accumulate(SumAll acc, long iValue) {
            acc.sum += iValue;
        }
    
        //減去要撤回的值
        public void retract(SumAll acc, long iValue) {
            acc.sum -= iValue;
        }
        
        //從每個分區(qū)把數(shù)據(jù)取出來然后合并
        public void merge(SumAll acc, Iterable<SumAll> it) {
    
            Iterator<SumAll> iter = it.iterator();
    
            while (iter.hasNext()) {
                SumAll a = iter.next();
                acc.sum += a.sum;
    
            }
        }
        //重置內(nèi)存中值時調(diào)用
        public void resetAccumulator(SumAll acc) {
            acc.sum = 0L;
        }
    
        public static class SumAll {
            public long sum = 0;
        }
    
    }

  • 自定義Aggregation Functions,需要繼承AggregateFunction,并且必須要有 以下的方法
    • createAccumulator() 創(chuàng)建一個保留中間結(jié)果的數(shù)據(jù)結(jié)構(gòu)
    • accumulate() 把每個輸入行與中間結(jié)果進(jìn)行計(jì)算,可以重載
    • getValue() 獲取最終結(jié)果
  • 根據(jù)不同的使用情況,還需要以下的方法
    • retract() 用于bounded OVER窗口,即窗口有結(jié)束時間
    • merge()用于多次批量聚合和會話窗口合并
    • resetAccumulator()用于多次批量聚合時,清空中間結(jié)果

Table Aggregation Functions

    package udf;

    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.table.functions.TableAggregateFunction;
    import org.apache.flink.types.Row;
    import org.apache.flink.util.Collector;
    
    public class TestTableAggregateFunction extends TableAggregateFunction<Row,TestTableAggregateFunction.Top2> {
        //創(chuàng)建保留中間結(jié)果的對象
        @Override
        public Top2 createAccumulator() {
            Top2 t = new Top2();
            t.f1 = Integer.MIN_VALUE;
            t.f2 = Integer.MIN_VALUE;
    
            return t;
        }
        //與傳入值進(jìn)行計(jì)算的方法
        public void accumulate(Top2 t, Integer v) {
            //如果傳入的值比內(nèi)存中第一個值大,那就用第一個值替換第二個值,傳入的值替換第一個值;
            //如果傳入的值比第二個值大比第一個小,那么就替換第二個值。
            if (v > t.f1) {
                t.f2 = t.f1;
                t.f1 = v;
            } else if (v > t.f2) {
                t.f2 = v;
            }
        }
        
        //合并分區(qū)的值
        public void merge(Top2 t, Iterable<Top2> iterable) {
            for (Top2 otherT : iterable) {
                accumulate(t, otherT.f1);
                accumulate(t, otherT.f2);
            }
        }
    
        //拿到返回結(jié)果的方法
        public void emitValue(Top2 t, Collector<Row> out) {
            Row row = null;
            //發(fā)射數(shù)據(jù)
            //如果第一個值不是最小的int值,那就發(fā)出去
            //如果第二個值不是最小的int值,那就發(fā)出去
            if (t.f1 != Integer.MIN_VALUE) {
                row = new Row(2);
                row.setField(0,t.f1);
                row.setField(1,1);
                out.collect(row);
            }
            if (t.f2 != Integer.MIN_VALUE) {
                row = new Row(2);
                row.setField(0,t.f2);
                row.setField(1,2);
                out.collect(row);
            }
        }
        //撤回流拿結(jié)果的方法,會發(fā)射撤回?cái)?shù)據(jù)
        public void emitUpdateWithRetract(Top2 t, RetractableCollector<Row> out) {
            Row row = null;
            //如果新舊值不相等,才需要撤回,不然沒必要
            //如果舊值不等于int最小值,說明之前發(fā)射過數(shù)據(jù),需要撤回
            //然后將新值發(fā)射出去
            if (!t.f1.equals(t.oldF1)) {
                if (t.oldF1 != Integer.MIN_VALUE) {
                    row = new Row(2);
                    row.setField(0,t.oldF1);
                    row.setField(1,1);
                    out.retract(row);
                }
                row = new Row(2);
                row.setField(0,t.f1);
                row.setField(1,1);
                out.collect(row);
                t.oldF1 = t.f1;
            }
            //和上面邏輯一樣,只是一個發(fā)射f1,一個f2
            if (!t.f2.equals(t.oldF2)) {
                // if there is an update, retract old value then emit new value.
                if (t.oldF2 != Integer.MIN_VALUE) {
                    row = new Row(2);
                    row.setField(0,t.oldF2);
                    row.setField(1,2);
                    out.retract(row);
                }
                row = new Row(2);
                row.setField(0,t.f2);
                row.setField(1,2);
                out.collect(row);
                t.oldF2 = t.f2;
            }
        }
        //保留中間結(jié)果的類
        public class Top2{
            public Integer f1;
            public Integer f2;
            public Integer oldF1;
            public Integer oldF2;
    
        }
    
        @Override
        public TypeInformation<Row> getResultType() {
            return Types.ROW(Types.INT,Types.INT);
        }
    }

  • 自定義Table Aggregation Functions,需要繼承TableAggregateFunction,并且必須要有 以下的方法
    • createAccumulator() 創(chuàng)建一個保留中間結(jié)果的數(shù)據(jù)結(jié)構(gòu)
    • accumulate() 把每個輸入行與中間結(jié)果進(jìn)行計(jì)算,可以重載
  • 根據(jù)不同的使用情況,還需要以下的方法
    • retract() 用于bounded OVER窗口,即窗口有結(jié)束時間
    • merge()用于多次批量聚合和會話窗口合并
    • resetAccumulator()用于多次批量聚合時,清空中間結(jié)果
    • emitValue() 用于批量和窗口聚合拿到結(jié)果
    • emitUpdateWithRetract() 用于流式計(jì)算的撤回流
  • 目前Table Aggregation Functions只支持在Table Api中使用

完整代碼

    //下面貼出來的是主類的代碼,具體每個UDF的類上面已經(jīng)有了
    package FlinkSql;


    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.datastream.DataStreamSource;
    import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
    import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.types.Row;
    import udf.TestAggregateFunction;
    import udf.TestScalarFunc;
    import udf.TestTableAggregateFunction;
    import udf.TestTableFunction;
    
    import static util.FlinkConstant.env;
    import static util.FlinkConstant.tEnv;
    
    public class FlinkSql04 {
        public static void main(String[] args) throws Exception {
    
    
            DataStream<Row> source = env.addSource(new RichSourceFunction<Row>() {
    
                @Override
                public void run(SourceContext<Row> ctx) throws Exception {
                        Row row = new Row(3);
                        row.setField(0, 2);
                        row.setField(1, 3);
                        row.setField(2, 3);
                        ctx.collect(row);
                }
    
                @Override
                public void cancel() {
    
                }
            }).returns(Types.ROW(Types.INT,Types.INT,Types.INT));
    
            tEnv.createTemporaryView("t",source,"a,b,c");
    
    //        tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
    
            tEnv.registerFunction("test",new TestScalarFunc());
    
            Table table = tEnv.sqlQuery("select test() as a,test(a) as b, test(a,b,c) as c from t");
    
            DataStream<Row> res = tEnv.toAppendStream(table, Row.class);
    
    //        res.print().name("Scalar Functions Print").setParallelism(1);
    
            DataStream<Row> ds2 = env.addSource(new RichSourceFunction<Row>() {
    
    
                @Override
                public void run(SourceContext<Row> ctx) throws Exception {
                        Row row = new Row(2);
                        row.setField(0, 22);
                        row.setField(1, "aa,b,cdd,dfsfdg,exxxxx");
                        ctx.collect(row);
                }
    
                @Override
                public void cancel() {
    
                }
            }).returns(Types.ROW(Types.INT, Types.STRING));
    
            tEnv.createTemporaryView("t2",ds2,"age,name_list");
    
            tEnv.registerFunction("test2",new TestTableFunction(","));
    
    //        Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)");
    
            Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE");
    
            DataStream<Row> res2 = tEnv.toAppendStream(table2, Row.class);
    
    //        res2.print().name("Table Functions Print").setParallelism(1);
    
            DataStream<Row> ds3 = env.addSource(new RichSourceFunction<Row>() {
                @Override
                public void run(SourceContext<Row> ctx) throws Exception {
                    Row row1 = new Row(2);
                    row1.setField(0,"a");
                    row1.setField(1,1L);
    
                    Row row2 = new Row(2);
                    row2.setField(0,"a");
                    row2.setField(1,2L);
    
                    Row row3 = new Row(2);
                    row3.setField(0,"b");
                    row3.setField(1,100L);
    
                    ctx.collect(row1);
                    ctx.collect(row2);
                    ctx.collect(row3);
    
                }
    
                @Override
                public void cancel() {
    
                }
            }).returns(Types.ROW(Types.STRING, Types.LONG));
    
            tEnv.createTemporaryView("t3",ds3,"name,cnt");
    
            tEnv.registerFunction("test3",new TestAggregateFunction());
    
            Table table3 = tEnv.sqlQuery("select name,test3(cnt) as mySum from t3 group by name");
    
            DataStream<Tuple2<Boolean, Row>> res3 = tEnv.toRetractStream(table3, Row.class);
    
    //        res3.print().name("Aggregate Functions Print").setParallelism(1);
    
            DataStream<Row> ds4 = env.addSource(new RichSourceFunction<Row>() {
                @Override
                public void run(SourceContext<Row> ctx) throws Exception {
                    Row row1 = new Row(2);
                    row1.setField(0,"a");
                    row1.setField(1,1);
    
                    Row row2 = new Row(2);
                    row2.setField(0,"a");
                    row2.setField(1,2);
    
                    Row row3 = new Row(2);
                    row3.setField(0,"a");
                    row3.setField(1,100);
    
                    ctx.collect(row1);
                    ctx.collect(row2);
                    ctx.collect(row3);
                }
    
                @Override
                public void cancel() {
    
                }
            }).returns(Types.ROW(Types.STRING, Types.INT));
    
            tEnv.createTemporaryView("t4",ds4,"name,cnt");
    
            tEnv.registerFunction("test4",new TestTableAggregateFunction());
    
            Table table4 = tEnv.sqlQuery("select * from t4");
    
            Table table5 = table4.groupBy("name")
                    .flatAggregate("test4(cnt) as (v,rank)")
                    .select("name,v,rank");
    
            DataStream<Tuple2<Boolean, Row>> res4 = tEnv.toRetractStream(table5, Row.class);
    
            res4.print().name("Aggregate Functions Print").setParallelism(1);
    
            env.execute("test udf");
    
        }
    }

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容