Java Spark UDF practice

package io.renren.utils.udf;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.UserDefinedFunction;

import java.io.Serializable;

import static org.apache.spark.sql.Encoders.*;

/**
 * @program: renren-cloud
 * @description: Spark typed and untyped UDAF (Aggregator) examples
 * @author: yyyyjinying
 * @create: 2023-06-20 14:11
 **/
public class JavaUserDefinedTypedAggregation {
    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class Employee implements Serializable {
        private String name;
        private long salary;
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class Average implements Serializable {
        private long sum;
        private long count;
    }

    public static class MyAverage extends Aggregator<Employee, Average, Double> {

        @Override
        public Average zero() {
            return new Average(0L, 0L);
        }

        @Override
        public Average reduce(Average buffer, Employee employee) {
            long newSum = buffer.getSum() + employee.getSalary();
            long newCount = buffer.getCount() + 1;
            buffer.setSum(newSum);
            buffer.setCount(newCount);
            return buffer;
        }

        @Override
        public Average merge(Average b1, Average b2) {
            long mergedSum = b1.getSum() + b2.getSum();
            long mergedCount = b1.getCount() + b2.getCount();
            b1.setSum(mergedSum);
            b1.setCount(mergedCount);
            return b1;
        }

        @Override
        public Double finish(Average reduction) {
            return ((double) reduction.getSum()) / reduction.getCount();
        }

        @Override
        public Encoder<Average> bufferEncoder() {
            return bean(Average.class);
        }

        @Override
        public Encoder<Double> outputEncoder() {
            return DOUBLE();
        }
    }

    public static class MyUnAverage extends Aggregator<Long, Average, Double> {
        // A zero value for this aggregation. Should satisfy the property that any b + zero = b
        @Override
        public Average zero() {
            return new Average(0L, 0L);
        }

        // Combine two values to produce a new value. For performance, the function may modify `buffer`
        // and return it instead of constructing a new object
        @Override
        public Average reduce(Average buffer, Long data) {
            long newSum = buffer.getSum() + data;
            long newCount = buffer.getCount() + 1;
            buffer.setSum(newSum);
            buffer.setCount(newCount);
            return buffer;
        }

        // Merge two intermediate values
        @Override
        public Average merge(Average b1, Average b2) {
            long mergedSum = b1.getSum() + b2.getSum();
            long mergedCount = b1.getCount() + b2.getCount();
            b1.setSum(mergedSum);
            b1.setCount(mergedCount);
            return b1;
        }

        // Transform the output of the reduction
        @Override
        public Double finish(Average reduction) {
            return ((double) reduction.getSum()) / reduction.getCount();
        }

        // Specifies the Encoder for the intermediate value type
        @Override
        public Encoder<Average> bufferEncoder() {
            return Encoders.bean(Average.class);
        }

        // Specifies the Encoder for the final output value type
        @Override
        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }

    public static SparkSession getSpark() {
        return SparkSession
                .builder()
                .appName("spark udaf example")
                .master("local[*]")
                .config("dfs.client.use.datanode.hostname", true)
                .getOrCreate();
    }

    public static void averageTypeUdf() {
        SparkSession spark = getSpark();
        Encoder<Employee> employeeEncoder = bean(Employee.class);
        String path = "renren-admin\\renren-admin-server\\src\\main\\resources\\employees.json";
        Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
        ds.show();
//        +-------+------+
//        |   name|salary|
//        +-------+------+
//        |Michael|  3000|
//        |   Andy|  4500|
//        | Justin|  3500|
//        |  Berta|  4000|
//        +-------+------+
        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        //        +--------------+
        //        |average_salary|
        //        +--------------+
        //        |        3750.0|
        //        +--------------+
        spark.stop();
    }

    public static void averageUnTypeUdf() {
        SparkSession spark = getSpark();
        String path = "renren-admin\\renren-admin-server\\src\\main\\resources\\employees.json";
        Dataset<Row> df = spark.read().json(path);
        df.createOrReplaceTempView("employees");
        df.show();

        spark.udf().register("myAverage", functions.udaf(new MyUnAverage(), LONG()));
        spark.sql("SELECT myAverage(salary) as average_salary FROM employees").show();
        spark.stop();
    }

    public static void main(String[] args) {
        // Strongly typed aggregation average (Dataset API)
//        averageTypeUdf();

        // Untyped aggregation average (registered for SQL)
        averageUnTypeUdf();


    }

}
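For reference, the employees.json file read by both methods matches the sample data shipped with Spark's official examples; based on the ds.show() output above, it contains one JSON object per line:

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}

If the file is not on hand, the same Dataset can be built in memory and fed to the typed aggregator. This is a minimal sketch reusing the Employee bean and MyAverage class defined above; only the in-memory construction is new:

        // Build the same four rows programmatically instead of reading the JSON file.
        Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
        Dataset<Employee> ds = spark.createDataset(
                java.util.Arrays.asList(
                        new Employee("Michael", 3000L),
                        new Employee("Andy", 4500L),
                        new Employee("Justin", 3500L),
                        new Employee("Berta", 4000L)),
                employeeEncoder);
        // Apply the typed aggregator as a named column; prints 3750.0.
        ds.select(new MyAverage().toColumn().name("average_salary")).show();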
