1)总结
通过 Calcite 的 RelBuilder 以编程方式直接构建 RelNode(关系表达式树),无需先编写并解析 SQL。下面的示例依次演示了 scan、project + filter 以及多表 join。
2)代码示例
MyRelBuilder
import cn.com.ptpress.cdm.ds.csv.CsvSchema;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.RelBuilder;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
 * Examples of building relational expressions ({@link RelNode}) directly with a
 * {@link RelBuilder}, without parsing SQL. Each test prints the resulting logical plan.
 *
 * <p>All examples run against the {@code csv} schema set up by {@link #config()}, whose
 * {@code data} table comes from the {@code data.csv} classpath resource.
 */
class MyRelBuilder {

    /**
     * Joins two self-joins of {@code data} (one on {@code Id}, one on {@code Name}) on their
     * {@code Score} columns. Expected plan:
     *
     * <pre>
     * LogicalJoin(condition=[=($2, $8)], joinType=[inner])
     *   LogicalJoin(condition=[=($0, $3)], joinType=[inner])
     *     LogicalTableScan(table=[[csv, data]])
     *     LogicalTableScan(table=[[csv, data]])
     *   LogicalJoin(condition=[=($1, $4)], joinType=[inner])
     *     LogicalTableScan(table=[[csv, data]])
     *     LogicalTableScan(table=[[csv, data]])
     * </pre>
     */
    @Test
    public void joinTest() {
        RelBuilder relBuilder = newBuilder();

        // Each join(type, names...) call equates the same-named field on both inputs.
        RelNode joinedOnId = relBuilder.scan("data").scan("data")
                .join(JoinRelType.INNER, "Id")
                .build();
        RelNode joinedOnName = relBuilder.scan("data").scan("data")
                .join(JoinRelType.INNER, "Name")
                .build();

        // Re-push both finished subtrees so they become the inputs of the outer join.
        relBuilder.push(joinedOnId);
        relBuilder.push(joinedOnName);
        RelNode plan = relBuilder.join(JoinRelType.INNER, "Score").build();

        printPlan(plan);
    }

    /**
     * Projects {@code Name} and {@code Score}, then keeps rows whose score exceeds 90.
     * Expected plan:
     *
     * <pre>
     * LogicalFilter(condition=[>($1, 90)])
     *   LogicalProject(Name=[$1], Score=[$2])
     *     LogicalTableScan(table=[[csv, data]])
     * </pre>
     */
    @Test
    public void projectWithFilterTest() {
        RelBuilder relBuilder = newBuilder();
        relBuilder.scan("data");
        // Field references resolve against whatever is currently on top of the builder's stack.
        relBuilder.project(relBuilder.field("Name"), relBuilder.field("Score"));
        relBuilder.filter(relBuilder.call(SqlStdOperatorTable.GREATER_THAN,
                relBuilder.field("Score"),
                relBuilder.literal(90)));
        printPlan(relBuilder.build());
    }

    /**
     * Scans the {@code data} table. Expected plan:
     *
     * <pre>
     * LogicalTableScan(table=[[csv, data]])
     * </pre>
     */
    @Test
    public void scanTest() {
        printPlan(newBuilder().scan("data").build());
    }

    /**
     * Returns a framework config builder whose default schema is {@code csv}, a
     * {@link CsvSchema} over the {@code data.csv} classpath resource.
     */
    public static Frameworks.ConfigBuilder config() {
        SchemaPlus root = Frameworks.createRootSchema(true);
        SchemaPlus csvSchema = root.add("csv", new CsvSchema("data.csv"));
        // The explicit cast selects the List overload of traitDefs(...) for the null argument.
        return Frameworks.newConfigBuilder()
                .parserConfig(SqlParser.Config.DEFAULT)
                .defaultSchema(csvSchema)
                .traitDefs((List<RelTraitDef>) null);
    }

    /** Creates a fresh RelBuilder from {@link #config()}. */
    private static RelBuilder newBuilder() {
        FrameworkConfig frameworkConfig = config().build();
        return RelBuilder.create(frameworkConfig);
    }

    /** Prints the textual plan of {@code node} to standard output. */
    private static void printPlan(RelNode node) {
        System.out.println(RelOptUtil.toString(node));
    }
}
CsvSchema
package cn.com.ptpress.cdm.ds.csv;
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;
import org.apache.calcite.util.Source;
import org.apache.calcite.util.Sources;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
/**
 * Calcite schema exposing one {@link CsvTable} per CSV resource on the classpath.
 *
 * <p>The constructor takes a comma-separated list of resource names such as {@code "data.csv"}.
 * Each table is named after its resource with everything from the first {@code '.'} removed,
 * so {@code data.csv} becomes table {@code data}.
 */
public class CsvSchema extends AbstractSchema {
    /** Table map, built on first use by {@link #getTableMap()}; {@code null} until then. */
    private Map<String, Table> tableMap;
    /** Comma-separated classpath resource names, exactly as passed to the constructor. */
    private final String dataFiles;

    /**
     * @param dataFile comma-separated classpath resource names of CSV files
     */
    public CsvSchema(String dataFile) {
        this.dataFiles = dataFile;
    }

    /**
     * Returns this schema's tables, creating them on first call and reusing them afterwards.
     *
     * <p>AbstractSchema derives its table lookups from this method, so it may be invoked many
     * times; caching avoids re-creating every CsvTable on each call.
     *
     * @throws IllegalArgumentException if a listed resource is not on the classpath
     */
    @Override
    protected Map<String, Table> getTableMap() {
        if (tableMap == null) {
            tableMap = createTableMap();
        }
        return tableMap;
    }

    /** Builds a CsvTable for every non-blank resource name listed in {@code dataFiles}. */
    private Map<String, Table> createTableMap() {
        Map<String, Table> tables = new HashMap<>();
        for (String rawName : dataFiles.split(",")) {
            // Tolerate "a.csv, b.csv" style lists.
            String dataFile = rawName.trim();
            if (dataFile.isEmpty()) {
                continue;
            }
            URL url = ClassLoader.getSystemClassLoader().getResource(dataFile);
            // An explicit check instead of `assert`, which is a no-op unless -ea is enabled.
            if (url == null) {
                throw new IllegalArgumentException(
                        "CSV resource not found on classpath: " + dataFile);
            }
            Source source = Sources.of(url);
            tables.put(dataFile.split("\\.")[0], new CsvTable(source));
        }
        return tables;
    }
}
CsvTable
package cn.com.ptpress.cdm.ds.csv;
import org.apache.calcite.DataContext;
import org.apache.calcite.linq4j.AbstractEnumerable;
import org.apache.calcite.linq4j.Enumerable;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.ScannableTable;
import org.apache.calcite.schema.impl.AbstractTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Source;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
/**
 * Calcite table backed by a CSV {@link Source}.
 *
 * <p>The first line of the source is a header of whitespace-separated {@code name:TYPE} tokens
 * (e.g. {@code Id:VARCHAR Name:VARCHAR Score:INTEGER}); it defines the row type. The remaining
 * lines are read by {@link CsvEnumerator} when the table is scanned.
 */
public class CsvTable extends AbstractTable implements ScannableTable {
    private final Source source;

    /**
     * @param source CSV source whose first line is the typed header
     */
    public CsvTable(Source source) {
        this.source = source;
    }

    /**
     * Derives the row type (column names and SQL types) from the header line.
     *
     * @throws IllegalStateException if the source is empty, a header token is not of the form
     *     {@code name:TYPE}, or TYPE is not a known {@link SqlTypeName}
     * @throws RuntimeException wrapping any {@link IOException} raised while reading
     */
    @Override
    public RelDataType getRowType(RelDataTypeFactory relDataTypeFactory) {
        List<String> names = new LinkedList<>();
        List<RelDataType> types = new LinkedList<>();
        // Read through Source#reader(), as CsvEnumerator does, instead of FileReader(source.file()):
        // file() only works for plain-file sources, not e.g. resources packaged in a jar.
        try (BufferedReader reader = new BufferedReader(source.reader())) {
            String header = reader.readLine();
            if (header == null) {
                throw new IllegalStateException("CSV source has no header line: " + source);
            }
            // Split on any run of whitespace so repeated spaces do not yield empty tokens.
            for (String token : header.trim().split("\\s+")) {
                String[] nameAndType = token.split(":");
                if (nameAndType.length < 2) {
                    throw new IllegalStateException(
                            "Malformed CSV header column '" + token + "', expected name:TYPE");
                }
                SqlTypeName typeName = SqlTypeName.get(nameAndType[1]);
                if (typeName == null) {
                    throw new IllegalStateException(
                            "Unknown SQL type '" + nameAndType[1] + "' in CSV header column '"
                                    + token + "'");
                }
                names.add(nameAndType[0]);
                types.add(relDataTypeFactory.createSqlType(typeName));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return relDataTypeFactory.createStructType(Pair.zip(names, types));
    }

    /** Returns an enumerable that re-reads the source with a new CsvEnumerator per enumeration. */
    @Override
    public Enumerable<Object[]> scan(DataContext dataContext) {
        return new AbstractEnumerable<Object[]>() {
            @Override
            public Enumerator<Object[]> enumerator() {
                return new CsvEnumerator<>(source);
            }
        };
    }
}
CsvEnumerator
package cn.com.ptpress.cdm.ds.csv;
import org.apache.calcite.avatica.util.DateTimeUtils;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Source;
import org.apache.commons.lang3.time.FastDateFormat;
import java.io.BufferedReader;
import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
/**
 * Enumerator over the data rows of a CSV {@link Source}.
 *
 * <p>The first line of the source is a header of whitespace-separated {@code name:TYPE} tokens.
 * Only the TYPE part is used here, to convert the comma-separated cells of each following line.
 * Every row is exposed as an {@code Object[]} with exactly one slot per header column.
 *
 * @param <E> element type; CsvTable instantiates it with {@code Object[]}
 */
public class CsvEnumerator<E> implements Enumerator<E> {
    private static final FastDateFormat TIME_FORMAT_DATE;
    private static final FastDateFormat TIME_FORMAT_TIME;
    private static final FastDateFormat TIME_FORMAT_TIMESTAMP;

    static {
        // All temporal values are parsed in GMT.
        final TimeZone gmt = TimeZone.getTimeZone("GMT");
        TIME_FORMAT_DATE = FastDateFormat.getInstance("yyyy-MM-dd", gmt);
        TIME_FORMAT_TIME = FastDateFormat.getInstance("HH:mm:ss", gmt);
        TIME_FORMAT_TIMESTAMP = FastDateFormat.getInstance("yyyy-MM-dd HH:mm:ss", gmt);
    }

    private E current;
    private final BufferedReader br;
    /** One entry per header column; {@code null} for type names SqlTypeName does not know. */
    private final List<SqlTypeName> types;

    /**
     * Opens {@code source} and consumes its header line.
     *
     * @throws IllegalStateException if the source is empty or a header token lacks {@code :TYPE}
     * @throws RuntimeException wrapping any {@link IOException}
     */
    CsvEnumerator(Source source) {
        BufferedReader reader;
        try {
            reader = new BufferedReader(source.reader());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        List<SqlTypeName> columnTypes;
        boolean headerRead = false;
        try {
            columnTypes = readColumnTypes(reader);
            headerRead = true;
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            // Do not leak the reader when the header cannot be parsed.
            if (!headerRead) {
                closeQuietly(reader);
            }
        }
        this.br = reader;
        this.types = columnTypes;
    }

    /** Parses the header line into one SqlTypeName per column (unknown names map to null). */
    private static List<SqlTypeName> readColumnTypes(BufferedReader reader) throws IOException {
        String header = reader.readLine();
        if (header == null) {
            throw new IllegalStateException("CSV source has no header line");
        }
        String[] columns = header.trim().split("\\s+");
        List<SqlTypeName> result = new ArrayList<>(columns.length);
        for (String column : columns) {
            String[] nameAndType = column.split(":");
            if (nameAndType.length < 2) {
                throw new IllegalStateException(
                        "Malformed CSV header column '" + column + "', expected name:TYPE");
            }
            result.add(SqlTypeName.get(nameAndType[1]));
        }
        return result;
    }

    /** Closes {@code reader}, ignoring failures because an earlier error is already propagating. */
    private static void closeQuietly(BufferedReader reader) {
        try {
            reader.close();
        } catch (IOException ignored) {
            // Keep the original exception as the one the caller sees.
        }
    }

    @Override
    public E current() {
        return current;
    }

    /**
     * Advances to the next data row and stores it in {@link #current()}.
     *
     * <p>Blank lines are skipped rather than treated as end of data. Missing trailing cells become
     * {@code null}.
     *
     * @return {@code false} once the source is exhausted
     * @throws IllegalStateException if a row has more cells than the header has columns
     * @throws RuntimeException wrapping any {@link IOException}, so read failures are not
     *     mistaken for the end of the data
     */
    @Override
    public boolean moveNext() {
        try {
            String line;
            do {
                line = br.readLine();
                if (line == null) {
                    return false;
                }
            } while (line.trim().isEmpty());

            // Limit -1 keeps trailing empty cells, so "1,a," yields three cells rather than two.
            final String[] values = line.split(",", -1);
            if (values.length > types.size()) {
                throw new IllegalStateException("CSV row has " + values.length
                        + " cells but the header declares " + types.size() + " columns: " + line);
            }
            // Size the row by the header so it always matches the table's row type.
            Object[] row = new Object[types.size()];
            for (int i = 0; i < values.length; i++) {
                row[i] = convert(types.get(i), values[i]);
            }
            @SuppressWarnings("unchecked") // E is Object[] for every caller (see CsvTable#scan)
            E next = (E) row;
            current = next;
            return true;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void reset() {
        throw new UnsupportedOperationException("Error");
    }

    @Override
    public void close() {
        try {
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Converts one raw cell to the Java value for {@code fieldType}.
     *
     * <p>A {@code null} type, VARCHAR and any type not handled below return the raw string. For the
     * handled types an empty cell yields {@code null}. DATE becomes an int day count since the
     * epoch, TIME an int millisecond count, TIMESTAMP a long millisecond count since the epoch (all
     * parsed in GMT); unparseable temporal values yield {@code null}.
     */
    private Object convert(SqlTypeName fieldType, String string) {
        if (fieldType == null) {
            return string;
        }
        switch (fieldType) {
            case BOOLEAN:
                if (string.isEmpty()) {
                    return null;
                }
                return Boolean.parseBoolean(string);
            case TINYINT:
                if (string.isEmpty()) {
                    return null;
                }
                return Byte.parseByte(string);
            case SMALLINT:
                if (string.isEmpty()) {
                    return null;
                }
                return Short.parseShort(string);
            case INTEGER:
                if (string.isEmpty()) {
                    return null;
                }
                return Integer.parseInt(string);
            case BIGINT:
                if (string.isEmpty()) {
                    return null;
                }
                return Long.parseLong(string);
            case FLOAT:
                if (string.isEmpty()) {
                    return null;
                }
                return Float.parseFloat(string);
            case DOUBLE:
                if (string.isEmpty()) {
                    return null;
                }
                return Double.parseDouble(string);
            case DATE:
                if (string.isEmpty()) {
                    return null;
                }
                try {
                    Date date = TIME_FORMAT_DATE.parse(string);
                    return (int) (date.getTime() / DateTimeUtils.MILLIS_PER_DAY);
                } catch (ParseException e) {
                    return null;
                }
            case TIME:
                if (string.isEmpty()) {
                    return null;
                }
                try {
                    Date date = TIME_FORMAT_TIME.parse(string);
                    return (int) date.getTime();
                } catch (ParseException e) {
                    return null;
                }
            case TIMESTAMP:
                if (string.isEmpty()) {
                    return null;
                }
                try {
                    Date date = TIME_FORMAT_TIMESTAMP.parse(string);
                    return date.getTime();
                } catch (ParseException e) {
                    return null;
                }
            case VARCHAR:
            default:
                return string;
        }
    }
}
resources/data.csv
Id:VARCHAR Name:VARCHAR Score:INTEGER
1,小明,90
2,小红,98
3,小亮,95