从 0 到 1 实现简单 ORM 框架【Java】
ORM 是什么?
ORM
指对象关系映射(Object Relational Mapping),在没有 ORM 框架 (MyBatis、Hibernate 等)操作数据库不仅需要自己写 SQL 语句,同时会有大量的重复步骤。而 ORM 能大大的简化这一过程.
开始动手
先创建一个简单的数据库表
CREATE TABLE student (
id` int NOT NULL COMMENT '学生ID',
nickname` varchar(32) NOT NULL COMMENT '姓名',
age tinyint UNSIGNED NOT NULL COMMENT '年龄',
PRIMARY KEY (`id`)
);
创建数据模型
package icu.twtool.entity;
/**
* 学生
* @author wen
* @since 2022-12-01
*/
public class Student {
/**
* 学生ID
*/
private Integer id;
/**
* 姓名
*/
private String nickname;
/**
* 年龄
*/
private Short age;
// 省略 getter/setter/toString
}
还需要准备什么?
名称映射工具类
Q:做什么?
A:将 Java 习惯的驼峰式命名与数据库习惯的下划线命名方式进行转化
package icu.twtool.orm.util;
import java.util.Arrays;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* 名称转换工具
* @author wen
* @since 2022-12-01
*/
public final class NameConvert {
private NameConvert() {}
/**
* 驼峰含首字母大写的正则表达式
*/
private final static Pattern CAMEL_NAME_PATTERN = Pattern.compile("([A-Z][^A-Z]*)");
/**
* 驼峰转下划线,首字母大写的也转为小写
* @param name 驼峰的名字
* @return 下划线的名字
*/
public static String camelToUnderline(String name) {
return CAMEL_NAME_PATTERN.matcher(name).replaceAll(match -> (match.start() == 0 ? "" : "_") + firstCharToLowerCase(match.group()));
}
/**
* 下划线转驼峰
* @param name 下划线的名字
* @return 驼峰的名字
*/
public static String underlineToCamel(String name) {
String camel = Arrays.stream(name.split("_")).map(NameConvert::firstCharToUpperCase).collect(Collectors.joining());
return name.isEmpty() ? camel : camel.replace(camel.charAt(0), Character.toLowerCase(camel.charAt(0)));
}
public static String firstCharToLowerCase(String word) {
if (word.isEmpty()) return word;
char firstChar = word.charAt(0);
return word.replace(firstChar, Character.toLowerCase(firstChar));
}
public static String firstCharToUpperCase(String word) {
if (word.isEmpty()) return word;
char firstChar = word.charAt(0);
return word.replace(firstChar, Character.toUpperCase(firstChar));
}
}
数据库连接工具类
写一个简单的数据库连接池类,这里的没有限制连接数上限,只会重复利用,同时会存在使用时认为 Connection 释放了,就不释放 Statement 的情况,需要优化
package icu.twtool.orm.util;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.locks.ReentrantLock;
/**
* 管理数据库连接
* @author wen
* @since 2022-12-01
*/
public final class DatabaseConnection {
private final static String url = "jdbc:mysql://localhost:3306/test";
private final static String username = "root";
private final static String password = "123456";
/**
* 保持活动的连接数
*/
private final static int MIN_COUNT = 5;
/**
* 当前可用的连接
*/
private final static Queue<Connection> availableConnections = new LinkedList<>();
private final static ReentrantLock lock = new ReentrantLock();
private DatabaseConnection() {}
/**
* 获取一个数据库连接对象
* @return 数据库连接对象
*/
public static Connection getConnection() {
lock.lock();
Connection result = null;
if (!availableConnections.isEmpty()) {
result = pollConnection();
}
try {
result = result != null ? result : createConnection();
} catch (SQLException e) {
throw new RuntimeException(e);
}
lock.unlock();
return result;
}
/**
* 从当前可用连接中获取可用的连接
* @return 如果存在可用连接就返回
*/
private static Connection pollConnection() {
Connection connection = availableConnections.poll();
try {
if (connection != null && connection.isValid(500)) return connection;
} catch (SQLException ignored) {
}
return null;
}
private static Connection createConnection() throws SQLException {
return (Connection) Proxy.newProxyInstance(DatabaseConnection.class.getClassLoader(), new Class[]{Connection.class}, new InvocationHandler() {
private final Connection connection = DriverManager.getConnection(url, username, password);
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if ("close".equals(method.getName())) {
lock.lock();
if (availableConnections.size() < MIN_COUNT) availableConnections.add((Connection) proxy);
else connection.close();
lock.unlock();
return null;
}
return method.invoke(connection, args);
}
});
}
}
定义一个函数式接口用来获取传入的方法的名称
package icu.twtool.orm.function;
import icu.twtool.orm.util.NameConvert;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.function.Function;
public interface OrmFunction<T, R> extends Function<T, R>, Serializable {
/**
* 获取字段名
* @return get方法获取的字段名
*/
default String getMethodName() {
try {
//WriteReplace改了好像会报异常
Method write = this.getClass().getDeclaredMethod("writeReplace");
write.setAccessible(true);
String methodName = ((SerializedLambda) write.invoke(this)).getImplMethodName();
if (methodName.startsWith("get")) methodName = methodName.substring(3);
return methodName;
} catch (Exception e) {
return null;
}
}
}
创建一个 ClassUtil
工具类获取一个数据模型类中的所有字段名
package icu.twtool.orm.util;
import icu.twtool.entity.Student;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Class 工具类
* @author wen
* @since 2022-12-01
*/
public final class ClassUtil {
public static List<String> getFieldNames(Class<?> clazz) {
List<String> result = new ArrayList<>();
if (clazz.equals(Object.class)) return result;
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
result.add(field.getName());
}
result.addAll(getFieldNames(clazz.getSuperclass()));
return result;
}
}
自定义一个运行时异常类
package icu.twtool.orm.exception;
/**
* 自定义一个运行时异常类
* @author wen
* @since 2022-12-01
*/
public class OrmException extends RuntimeException {
public OrmException(String message) {
super(message);
}
public OrmException(String message, Throwable cause) {
super(message, cause);
}
}
开始编写 Query 类
package icu.twtool.orm;
import icu.twtool.orm.exception.OrmException;
import icu.twtool.orm.function.OrmFunction;
import icu.twtool.orm.util.ClassUtil;
import icu.twtool.orm.util.DatabaseConnection;
import icu.twtool.orm.util.NameConvert;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* 执行 DQL
*
* @author wen
* @since 2022-12-01
*/
public final class Query<T> {
/**
* 最终返回的数据模型
*/
private final Class<T> dataModel;
/**
* 要查询的列,默认为数据模型的所有列
*/
private List<String> selectColumns;
/**
* 查询的表
*/
private String from;
/**
* Query 的构造器,私有
* @param clazz 数据模型的 Class 对象
*/
private Query(Class<T> clazz) {
dataModel = clazz;
selectColumns = ClassUtil.getFieldNames(clazz).stream().map(NameConvert::camelToUnderline).collect(Collectors.toList());
from = NameConvert.camelToUnderline(clazz.getSimpleName());
}
/**
* 构建 Query 对象
* @param clazz 数据模型 Class
* @param <T> 数据模型
*/
public static <T> Query<T> builder(Class<T> clazz) {
return new Query<>(clazz);
}
/**
* 要查询的列
* @param columns 要查询的列的 get 方法的 Lambda 方法
*/
@SafeVarargs
public final Query<T> select(OrmFunction<T, ?>... columns) {
this.selectColumns = Arrays.stream(columns).map(OrmFunction::getMethodName).map(NameConvert::camelToUnderline).collect(Collectors.toList());
return this;
}
/**
* 要查询的表名
*/
public Query<T> from(String from) {
this.from = from;
return this;
}
/**
* 通过 Class 对象实则要查询的表面
*/
public Query<T> from(Class<?> clazz) {
return from(NameConvert.camelToUnderline(clazz.getSimpleName()));
}
/**
* 构建预处理的 SQL 语句
*/
private String buildSQL() {
return "select " + String.join(",", selectColumns) + " from " + from;
}
/**
* 从 ResultSet 中获取一列数据,同时填充到要返回的数据中
*/
private void fillResultByColumn(T data, String column, ResultSet resultSet) {
try {
Field field = dataModel.getDeclaredField(NameConvert.underlineToCamel(column));
Class<?> type = field.getType();
try {
Object result = null;
if (Short.class.equals(type)) {
result = resultSet.getShort(column);
if (resultSet.wasNull()) result = null;
} else if (Integer.class.equals(type)) {
result = resultSet.getInt(column);
if (resultSet.wasNull()) result = null;
} else if (String.class.equals(type)) {
result = resultSet.getString(column);
}
if (result != null) {
field.setAccessible(true);
field.set(data, result);
}
} catch (SQLException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new OrmException("填充字段输出错误:" + column, e);
}
} catch (NoSuchFieldException e) {
throw new OrmException("数据模型没有对应的字段:" + column, e);
}
}
/**
* 通过 ResultSet 构造一个需要返回的数据模型对象
*/
private T getObjByResultSet(ResultSet resultSet) {
try {
T result = dataModel.getConstructor().newInstance();
selectColumns.forEach(column -> fillResultByColumn(result, column, resultSet));
return result;
} catch (NoSuchMethodException | InvocationTargetException | InstantiationException |
IllegalAccessException e) {
throw new OrmException("数据模型没有无参构造器", e);
}
}
/**
* 查询一个数据模型列表
*/
public List<T> list() {
try (Connection connection = DatabaseConnection.getConnection()) {
PreparedStatement statement = connection.prepareStatement(buildSQL());
ResultSet resultSet = statement.executeQuery();
List<T> result = new ArrayList<>();
while (resultSet.next()) result.add(getObjByResultSet(resultSet));
return result;
} catch (Exception e) {
if (e instanceof RuntimeException) throw (RuntimeException) e;
throw new RuntimeException(e);
}
}
}
测试
public static void main(String[] args) {
System.out.println(Query.builder(Student.class).list());
System.out.println(Query.builder(Student.class).select(Student::getNickname).list());
}
-------------------
[Student{id=1, nickname='Wen', age=null}]
[Student{id=null, nickname='Wen', age=null}]
其它优化和设计点
Update
与 Insert
以及Delete
,同时 Query
还能添加 where
语句相关的 eq
等语句,或者其它子句