package io.renren.utils.udf;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import java.io.Serializable;
import static org.apache.spark.sql.Encoders.*;
/**
* @program: renren-cloud
* @description:
* @author: yyyyjinying
* @create: 2023-06-20 14:11
**/
public class JavaUserDefinedTypedAggregation {
// Input row type for the typed aggregation demo; Lombok generates
// getters/setters, equals/hashCode/toString and both constructors.
// Serializable because instances are shipped between Spark executors.
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class Employee implements Serializable {
// Value of the "name" column in employees.json.
private String name;
// Value of the "salary" column; long matches Spark's default JSON integer decoding.
private long salary;
}
// Mutable intermediate buffer shared by both aggregators below:
// running sum of salaries plus the number of rows folded in so far.
// Must be a Serializable Java bean so Encoders.bean(...) can encode it.
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class Average implements Serializable {
// Running total of all salaries seen so far.
private long sum;
// Number of rows accumulated; finish() divides sum by this.
private long count;
}
/**
 * Strongly-typed (Dataset API) aggregator computing the average of
 * {@link Employee#getSalary()}.
 *
 * <p>Kept consistent with the sibling {@link MyUnAverage}: encoders are
 * obtained through explicit {@code Encoders.*} calls rather than static
 * imports, and the per-method contract comments mirror that class.
 */
public static class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy: any b merged with zero() equals b.
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Fold one input row into the buffer. For performance the buffer is
    // mutated and returned instead of allocating a new object.
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two partial buffers produced on different partitions.
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the final buffer into the output value.
    @Override
    public Double finish(Average reduction) {
        if (reduction.getCount() == 0L) {
            // Same value 0.0/0 would produce, but made explicit: an empty
            // input yields NaN rather than throwing.
            return Double.NaN;
        }
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Encoder for the intermediate buffer type.
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Encoder for the final Double output.
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
/**
 * Aggregator backing the untyped (SQL) UDAF registered in
 * {@code averageUnTypeUdf}: averages a column of {@code Long} salaries
 * via a mutable {@link Average} buffer.
 */
public static class MyUnAverage extends Aggregator<Long, Average, Double> {
    /** Neutral element: any buffer merged with zero() is unchanged. */
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    /** Folds one input value into the running buffer (mutates and returns it). */
    @Override
    public Average reduce(Average acc, Long value) {
        acc.setSum(acc.getSum() + value);
        acc.setCount(acc.getCount() + 1L);
        return acc;
    }

    /** Combines two partial buffers produced on different partitions. */
    @Override
    public Average merge(Average left, Average right) {
        left.setSum(left.getSum() + right.getSum());
        left.setCount(left.getCount() + right.getCount());
        return left;
    }

    /** Final result: sum / count (yields NaN when no rows were aggregated). */
    @Override
    public Double finish(Average acc) {
        return ((double) acc.getSum()) / acc.getCount();
    }

    /** Spark encoder for the intermediate buffer type. */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /** Spark encoder for the final Double output. */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
/**
 * Builds (or reuses) a local SparkSession for the UDAF demos.
 *
 * @return a session running with all local cores
 */
public static SparkSession getSpark() {
    SparkSession.Builder builder = SparkSession.builder();
    builder.appName("spark udaf example");
    builder.master("local[*]");
    builder.config("dfs.client.use.datanode.hostname", true);
    return builder.getOrCreate();
}
/**
 * Demonstrates the strongly-typed (Dataset) aggregation API: reads
 * employees.json as a {@code Dataset<Employee>}, applies {@link MyAverage}
 * as a named {@code TypedColumn} and prints the result.
 *
 * <p>Fixes over the original: the input path uses forward slashes (the old
 * backslash-separated literal only resolved on Windows; Spark/Hadoop accept
 * forward slashes on every OS), and {@code spark.stop()} now runs in a
 * {@code finally} block so the session is released even when reading or
 * aggregating throws.
 */
public static void averageTypeUdf() {
    SparkSession spark = getSpark();
    try {
        // $example on:typed_custom_aggregation$
        Encoder<Employee> employeeEncoder = bean(Employee.class);
        String path = "renren-admin/renren-admin-server/src/main/resources/employees.json";
        Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
        ds.show();
        // +-------+------+
        // | name|salary|
        // +-------+------+
        // |Michael| 3000|
        // | Andy| 4500|
        // | Justin| 3500|
        // | Berta| 4000|
        // +-------+------+
        MyAverage myAverage = new MyAverage();
        // Convert the aggregator to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // | 3750.0|
        // +--------------+
    } finally {
        // Always release the session, even if any step above fails.
        spark.stop();
    }
}
/**
 * Demonstrates the untyped (SQL) aggregation API: registers
 * {@link MyUnAverage} via {@code functions.udaf} and invokes it from a
 * SQL query over a temp view.
 *
 * <p>Fixes over the original: the SparkSession is now stopped in a
 * {@code finally} block (the original leaked it entirely, unlike the
 * sibling {@code averageTypeUdf} which stops its session), and the input
 * path uses portable forward slashes instead of Windows-only backslashes.
 */
public static void averageUnTypeUdf() {
    SparkSession spark = getSpark();
    try {
        String path = "renren-admin/renren-admin-server/src/main/resources/employees.json";
        Dataset<Row> df = spark.read().json(path);
        df.createOrReplaceTempView("employees");
        df.show();
        // LONG() declares the encoder for the salary values fed to the aggregator.
        spark.udf().register("myAverage", functions.udaf(new MyUnAverage(), LONG()));
        spark.sql("SELECT myAverage(salary) as average_salary FROM employees").show();
    } finally {
        // Release the session even if the query fails.
        spark.stop();
    }
}
/**
 * Entry point: runs the untyped UDAF demo; the typed variant is kept
 * commented out for reference.
 */
public static void main(String[] args) {
    // Strongly-typed aggregation average (disabled)
    // averageTypeUdf();
    // Untyped aggregation average
    averageUnTypeUdf();
}
}
// java spark udf 练习 (Java Spark UDF practice)
// Trailing blog-scrape footer ("latest recommended article published 2024-05-31 10:46:10")
// commented out so the file compiles.