017:Mybatis深入源码分析之高仿Mybatis框架
1 高仿MyBatis框架代码效果演示
课程内容:
1、MyBatis执行原理分析
2、纯手写Configuration
3、纯手写SqlSessionFactoryBuilder
4、纯手写MapperProxy
2 简单回顾MyBatis源码分析流程
3 高仿SqlSessionFactoryBuilder
my_config.properties
# Demo configuration for the hand-written MyBatis clone; loaded from the classpath
# (resource name "my_config.properties") by XMLConfigBuilder through PropertiesUtil.
# JDBC settings, copied into the DataSource held by Configuration.
driver=com.mysql.jdbc.Driver
url=jdbc:mysql://localhost:3306/test
# NOTE(review): plaintext demo credentials - do not reuse outside local testing.
username=root
password=root
# Package scanned (including sub-packages) for mapper interfaces to register.
mappers=com.mayikt.mappers
UserEntity
/**
 * Demo entity returned by {@code UserMapper#selectUser()}.
 *
 * <p>Lombok's {@code @Data} generates accessors, {@code equals} and {@code hashCode};
 * {@code toString} is written by hand below.
 */
@Data
public class UserEntity {
    /** User name. */
    private String userName;
    /** User id. */
    private Integer userId;

    /** Creates an empty entity. */
    public UserEntity() {
    }

    /**
     * Creates an entity with both fields set.
     *
     * @param userName user name
     * @param userId   user id
     */
    public UserEntity(String userName, Integer userId) {
        this.userName = userName;
        this.userId = userId;
    }

    @Override
    public String toString() {
        return new StringBuilder("UserEntity{")
                .append("userName='").append(userName).append('\'')
                .append(", userId=").append(userId)
                .append('}')
                .toString();
    }
}
UserMapper
/**
 * Mapper interface for users; callers obtain an implementation through
 * {@code SqlSession#getMapper(UserMapper.class)}.
 */
public interface UserMapper {

    /**
     * Selects a user.
     *
     * @return the selected user
     */
    UserEntity selectUser();
}
BaseBuilder
/**
 * Common base for configuration builders: carries the {@link Configuration} instance that a
 * subclass creates and fills in.
 */
public class BaseBuilder {

    /** Configuration being assembled by the concrete builder. */
    protected Configuration configuration;
}
Configuration
/**
 * Global configuration: the data source, the package that is scanned for mappers, and the
 * registry of known mapper interfaces (each mapped to the factory that builds its proxy).
 */
@Data
public class Configuration {
    /**
     * Data source settings.
     */
    private DataSource dataSource;
    /**
     * Package name scanned for mapper interfaces.
     */
    private String mappers;
    /**
     * Registered mapper interfaces and their proxy factories.
     */
    private final Map<Class<?>, MapperProxyFactory<?>> knownMappers = new HashMap<Class<?>, MapperProxyFactory<?>>();

    /**
     * Registers a mapper interface.
     *
     * <p>Only interfaces are registered: the factory builds a JDK dynamic proxy, and
     * {@code Proxy.newProxyInstance} rejects classes. A package scan can also return ordinary
     * classes, so non-interfaces (and {@code null}) are ignored here instead of failing later
     * when the mapper is requested.
     *
     * @param type mapper interface to register
     */
    public <T> void addMapper(Class<T> type) {
        if (type == null || !type.isInterface()) {
            return;
        }
        knownMappers.put(type, new MapperProxyFactory<>(type));
    }

    /**
     * Returns the proxy factory registered for {@code type}.
     *
     * @param type mapper interface
     * @return the factory, or {@code null} if {@code type} was never registered
     */
    @SuppressWarnings("unchecked") // addMapper always stores MapperProxyFactory<T> under Class<T>
    public <T> MapperProxyFactory<T> getMapperProxyFactory(Class<T> type) {
        return (MapperProxyFactory<T>) knownMappers.get(type);
    }

    /**
     * Legacy lookup kept for source compatibility.
     *
     * <p>NOTE: despite the declared return type {@code T}, the object returned is the
     * {@link MapperProxyFactory} registered for {@code type} (or {@code null}), not a mapper
     * instance; callers must cast it to {@code MapperProxyFactory}. Prefer
     * {@link #getMapperProxyFactory(Class)}.
     *
     * @param type mapper interface
     * @return the registered factory typed as {@code T}, or {@code null}
     */
    @SuppressWarnings("unchecked")
    public <T> T getMapper(Class<T> type) {
        return (T) getMapperProxyFactory(type);
    }
}
DataSource
/**
 * JDBC connection settings (driver class, URL and credentials) as read from the properties file.
 *
 * <p>Lombok's {@code @Data} generates the accessors, {@code equals}, {@code hashCode} and
 * {@code toString}.
 */
@Data
public class DataSource {
    private String driver;
    private String url;
    private String username;
    private String password;

    /** Creates an instance with every field unset. */
    public DataSource() {
    }

    /**
     * Creates a fully populated instance.
     *
     * @param driver   JDBC driver class name
     * @param url      JDBC URL
     * @param username database user
     * @param password database password
     */
    public DataSource(String driver, String url, String username, String password) {
        this.password = password;
        this.username = username;
        this.url = url;
        this.driver = driver;
    }
}
XMLConfigBuilder
/**
 * Builds a {@link Configuration} from a classpath properties file.
 *
 * <p>It copies {@code driver}/{@code url}/{@code username}/{@code password} into a
 * {@link DataSource}. It then registers every interface found in the package named by the
 * {@code mappers} property.
 */
public class XMLConfigBuilder extends BaseBuilder {
    private String propertiesName;
    private PropertiesUtil propertiesUtil;

    /**
     * @param propertiesName classpath resource name of the properties file,
     *                       e.g. {@code "my_config.properties"}
     */
    public XMLConfigBuilder(String propertiesName) {
        this.propertiesName = propertiesName;
        this.configuration = new Configuration();
        this.propertiesUtil = new PropertiesUtil(propertiesName);
    }

    /**
     * Parses the properties file into the configuration.
     *
     * @return the populated configuration
     */
    public Configuration parse() {
        // 1. Read the data source settings.
        parseDataSource();
        // 2. Scan and register the mapper interfaces.
        parseMapperse();
        return this.configuration;
    }

    /**
     * Reads the JDBC settings and stores them as the configuration's {@link DataSource}.
     */
    public void parseDataSource() {
        // Load the file once instead of re-opening it for each key.
        java.util.Properties properties = propertiesUtil.getProperties();
        DataSource dataSource = new DataSource(
                properties.getProperty("driver"),
                properties.getProperty("url"),
                properties.getProperty("username"),
                properties.getProperty("password"));
        this.configuration.setDataSource(dataSource);
    }

    /**
     * Reads the {@code mappers} package name and registers every interface found in it
     * (sub-packages included).
     *
     * @throws IllegalStateException if {@code mappers} is missing or blank. Scanning with an
     *                               empty package name would walk the classpath roots.
     */
    public void parseMapperse() {
        String mappers = propertiesUtil.getProperties().getProperty("mappers");
        if (mappers == null || mappers.trim().isEmpty()) {
            throw new IllegalStateException(
                    "Property 'mappers' is missing or empty in " + propertiesName);
        }
        // Properties keeps trailing blanks in values; a package name never contains them.
        mappers = mappers.trim();
        this.configuration.setMappers(mappers);
        // 3. Use reflection to find the classes of the package; only interfaces can be proxied.
        for (Class<?> candidate : ClassUtil.getClassSet(mappers)) {
            if (candidate.isInterface()) {
                this.configuration.addMapper(candidate);
            }
        }
    }
}
SqlSessionFactoryBuilder
/**
 * Entry point that turns configuration into a {@link SqlSessionFactory}.
 */
public class SqlSessionFactoryBuilder {

    /**
     * Parses the given properties file (e.g. {@code my_config.properties}) and builds a factory
     * for the resulting configuration.
     *
     * @param propertiesName classpath resource name of the properties file
     * @return a factory backed by the parsed configuration
     */
    public SqlSessionFactory build(String propertiesName) {
        Configuration parsed = new XMLConfigBuilder(propertiesName).parse();
        return build(parsed);
    }

    /**
     * Wraps an already-built configuration in a {@link DefaultSqlSessionFactory}.
     *
     * @param configuration the configuration to use
     * @return a new factory
     */
    public SqlSessionFactory build(Configuration configuration) {
        return new DefaultSqlSessionFactory(configuration);
    }
}
4 高仿DefaultSqlSessionFactory
SqlSessionFactory
/**
 * Factory for {@link SqlSession} instances.
 */
public interface SqlSessionFactory {

    /**
     * Opens a session.
     *
     * @return a session bound to this factory's configuration
     */
    SqlSession openSqlSession();
}
DefaultSqlSessionFactory
/**
 * Default factory: each call returns a new {@link DefaultSqlSession} that shares this factory's
 * configuration.
 */
public class DefaultSqlSessionFactory implements SqlSessionFactory {

    private final Configuration configuration;

    /**
     * @param configuration configuration handed to every session this factory opens
     */
    public DefaultSqlSessionFactory(Configuration configuration) {
        this.configuration = configuration;
    }

    @Override
    public SqlSession openSqlSession() {
        final SqlSession session = new DefaultSqlSession(configuration);
        return session;
    }
}
5 高仿SqlSession会话接口
SqlSession
/**
 * A session through which mapper implementations are obtained.
 */
public interface SqlSession {

    /**
     * Returns an implementation of the given mapper interface.
     *
     * @param type mapper interface
     * @param <T>  mapper type
     * @return the mapper implementation
     * @throws Exception if no implementation can be provided for {@code type}
     */
    <T> T getMapper(Class<T> type) throws Exception;
}
DefaultSqlSession
/**
 * Default {@link SqlSession}. It looks up the proxy factory registered for a mapper interface
 * and asks it for a proxy bound to this session.
 */
public class DefaultSqlSession implements SqlSession {
    private final Configuration configuration;

    /**
     * @param configuration configuration holding the registered mapper interfaces
     */
    public DefaultSqlSession(Configuration configuration) {
        this.configuration = configuration;
    }

    /**
     * Returns a proxy implementing {@code type}.
     *
     * @param type mapper interface
     * @return the mapper proxy
     * @throws IllegalArgumentException if {@code type} was never registered in the configuration
     */
    @Override
    @SuppressWarnings("unchecked") // the registry stores MapperProxyFactory<T> under Class<T>
    public <T> T getMapper(Class<T> type) throws Exception {
        // Configuration#getMapper returns the registered MapperProxyFactory (typed as T),
        // not a mapper instance.
        Object registered = configuration.getMapper(type);
        if (registered == null) {
            throw new IllegalArgumentException("Type " + type + " is not known to the MapperProxyFactory.");
        }
        MapperProxyFactory<T> mapperProxyFactory = (MapperProxyFactory<T>) registered;
        return mapperProxyFactory.newInstance(this);
    }
}
6 使用Java反射机制注册Mapper接口
MapperProxyFactory
/**
 * Creates JDK dynamic proxies for one mapper interface.
 *
 * @param <T> the mapper interface type
 */
public class MapperProxyFactory<T> {
    private final Class<T> mapperInterface;

    /**
     * @param mapperInterface the interface the created proxies implement
     */
    public MapperProxyFactory(Class<T> mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    /**
     * @return the mapper interface this factory proxies
     */
    public Class<T> getMapperInterface() {
        return mapperInterface;
    }

    /**
     * Creates a proxy whose calls are handled by a new {@link MapperProxy} bound to
     * {@code sqlSession}.
     *
     * @param sqlSession session passed to the invocation handler
     * @return the mapper proxy
     */
    public T newInstance(SqlSession sqlSession) {
        final MapperProxy<T> mapperProxy = new MapperProxy<T>(sqlSession, mapperInterface);
        return newInstance(mapperProxy);
    }

    /**
     * Creates the proxy object using a JDK dynamic proxy.
     *
     * @param mapperProxy invocation handler for the proxy
     * @return the mapper proxy
     */
    @SuppressWarnings("unchecked") // the proxy implements exactly mapperInterface, i.e. T
    public T newInstance(MapperProxy<T> mapperProxy) {
        // The proxy class is defined in the mapper interface's own loader. That loader can always
        // see the interface. MapperProxy's loader may not, e.g. when the mapper comes from a child
        // loader, and Proxy would then fail with "not visible from class loader".
        return (T) Proxy.newProxyInstance(
                mapperInterface.getClassLoader(), new Class<?>[]{mapperInterface}, mapperProxy);
    }
}
7 使用代理模式高仿生成Mapper接口
MapperProxy
/**
 * Invocation handler behind every mapper proxy.
 *
 * <p>Methods declared by {@code Object} are answered locally. Every mapper method currently
 * returns a hard-coded demo {@link UserEntity}; no SQL is executed yet.
 *
 * @param <T> the mapper interface type
 */
public class MapperProxy<T> implements InvocationHandler {

    /**
     * @param sqlSession      session the proxy was created for (not used yet)
     * @param mapperInterface proxied mapper interface (not used yet)
     */
    public MapperProxy(SqlSession sqlSession, Class<T> mapperInterface) {
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // equals/hashCode/toString called on the proxy are dispatched here as well. Answering
        // them with a UserEntity would throw ClassCastException in the caller (e.g. when the
        // mapper is printed or put into a HashMap), so they get proxy-identity semantics instead.
        if (Object.class.equals(method.getDeclaringClass())) {
            switch (method.getName()) {
                case "equals":
                    return proxy == args[0];
                case "hashCode":
                    return System.identityHashCode(proxy);
                default:
                    return method.invoke(this, args);
            }
        }
        // Placeholder result until SQL execution is implemented.
        return new UserEntity("mayikt", 666);
    }
}
TestMyBatis
/**
 * Demo driver. It builds a factory from {@code my_config.properties}, obtains a
 * {@link UserMapper} proxy and prints the result of {@code selectUser()}.
 */
public class TestMyBatis {
    public static void main(String[] args) throws Exception {
        // 1. Build the default factory from the properties file and open a session.
        SqlSession session = new SqlSessionFactoryBuilder()
                .build("my_config.properties")
                .openSqlSession();
        // 2. Obtain the dynamically generated UserMapper proxy.
        UserMapper mapper = session.getMapper(UserMapper.class);
        // 3. The call is routed to MapperProxy#invoke.
        System.out.println(mapper.selectUser());
    }
}
测试结果(控制台输出):
UserEntity{userName='mayikt', userId=666}
工具类
PropertiesUtil
import java.io.InputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Properties;
/**
 * Small helper for reading and writing {@code .properties} files.
 *
 * <p>{@link #readProperty} and {@link #getProperties} look the file up as a classpath resource,
 * while {@link #writeProperty} works on a filesystem path. I/O errors are printed and a default
 * value is returned (best-effort). A resource missing from the classpath is reported with an
 * {@link IllegalArgumentException} naming the resource.
 */
public class PropertiesUtil {
    private String propertiesName = "";

    public PropertiesUtil() {
    }

    /**
     * @param fileName classpath resource name (for reads) or filesystem path (for writes)
     */
    public PropertiesUtil(String fileName) {
        this.propertiesName = fileName;
    }

    /**
     * Reads one value from the classpath properties file.
     *
     * @param key property key
     * @return the value; {@code null} if the key is absent; {@code ""} if reading failed
     * @throws IllegalArgumentException if the resource is not on the classpath
     */
    public String readProperty(String key) {
        Properties p = new Properties();
        try {
            loadFromClasspath(p);
        } catch (IOException e) {
            e.printStackTrace();
            return "";
        }
        return p.getProperty(key);
    }

    /**
     * Loads the whole classpath properties file.
     *
     * @return the loaded properties (whatever was read before an I/O error, if one occurs)
     * @throws IllegalArgumentException if the resource is not on the classpath
     */
    public Properties getProperties() {
        Properties p = new Properties();
        try {
            loadFromClasspath(p);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return p;
    }

    /**
     * Sets {@code key=value} in the properties file at the filesystem path {@code propertiesName}
     * and rewrites the file. The file must already exist. Failures are printed, not thrown.
     *
     * @param key   property key
     * @param value property value
     */
    public void writeProperty(String key, String value) {
        Properties p = new Properties();
        try {
            try (InputStream is = new FileInputStream(propertiesName)) {
                p.load(is);
            }
            // Update the in-memory copy before the output stream truncates the file, so an
            // invalid key/value (e.g. null) no longer leaves an empty file behind.
            p.setProperty(key, value);
            try (OutputStream os = new FileOutputStream(propertiesName)) {
                p.store(os, key);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads the classpath resource {@code propertiesName} into {@code target}; the stream is
     * always closed.
     */
    private void loadFromClasspath(Properties target) throws IOException {
        InputStream resource = PropertiesUtil.class.getClassLoader().getResourceAsStream(propertiesName);
        if (resource == null) {
            throw new IllegalArgumentException("Properties resource not found on classpath: " + propertiesName);
        }
        try (InputStream in = resource) {
            target.load(in);
        }
    }

    public static void main(String[] args) {
        // sysConfig.properties (configuration file)
        PropertiesUtil p = new PropertiesUtil("sysConfig.properties");
        System.out.println(p.getProperties().get("db.url"));
        System.out.println(p.readProperty("db.url"));
        PropertiesUtil q = new PropertiesUtil("resources/sysConfig.properties");
        q.writeProperty("myUtils", "wang");
        System.exit(0);
    }
}
ClassUtil
import java.io.File;
import java.io.FileFilter;
import java.net.JarURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * Class-loading helpers, mainly used to find every class under a package (package scanning).
 *
 * <p>All lookups use the current thread's context class loader.
 */
public final class ClassUtil {
    /**
     * Returns the class loader used by this utility (the thread context class loader).
     */
    public static ClassLoader getClassLoader() {
        return Thread.currentThread().getContextClassLoader();
    }

    /**
     * Loads a class by name.
     *
     * @param className     fully qualified class name
     * @param isInitialized whether static initialization should run
     * @return the loaded class
     * @throws RuntimeException wrapping {@link ClassNotFoundException}
     */
    public static Class<?> loadClass(String className, boolean isInitialized) {
        try {
            return Class.forName(className, isInitialized, getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads a class by name and initializes it.
     */
    public static Class<?> loadClass(String className) {
        return loadClass(className, true);
    }

    /**
     * Returns every class under {@code packageName}, sub-packages included. Classes come from
     * directories and jar files visible to the context class loader and are loaded without
     * being initialized.
     *
     * @param packageName dotted package name, e.g. {@code com.mayikt.mappers}
     * @return the loaded classes, possibly empty
     * @throws RuntimeException wrapping any I/O, URI or class-loading failure
     */
    public static List<Class<?>> getClassSet(String packageName) {
        List<Class<?>> classSet = new ArrayList<Class<?>>();
        try {
            String packagePath = packageName.replace(".", "/");
            Enumeration<URL> urls = getClassLoader().getResources(packagePath);
            while (urls.hasMoreElements()) {
                URL url = urls.nextElement();
                if (url == null) {
                    continue;
                }
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    // Full URI decoding: replacing only "%20" broke paths containing '#',
                    // non-ASCII characters, or Windows drive letters.
                    addClass(classSet, new File(url.toURI()).getPath(), packageName);
                } else if ("jar".equals(protocol)) {
                    JarFile jarFile = ((JarURLConnection) url.openConnection()).getJarFile();
                    addJarClasses(classSet, jarFile, packagePath);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return classSet;
    }

    /** Adds the classes of {@code jarFile} located under {@code packagePath} (sub-packages included). */
    private static void addJarClasses(List<Class<?>> classSet, JarFile jarFile, String packagePath) {
        // Only entries of the requested package are wanted: the old code loaded every class in
        // the jar. The trailing '/' keeps "com/foo" from matching "com/foobar".
        String prefix = packagePath.isEmpty() || packagePath.endsWith("/") ? packagePath : packagePath + "/";
        Enumeration<JarEntry> jarEntries = jarFile.entries();
        while (jarEntries.hasMoreElements()) {
            String entryName = jarEntries.nextElement().getName();
            if (entryName.startsWith(prefix) && entryName.endsWith(".class")) {
                String className = entryName.substring(0, entryName.lastIndexOf(".")).replace('/', '.');
                doAddClass(classSet, className);
            }
        }
    }

    /** Recursively adds the {@code .class} files below directory {@code packagePath}. */
    private static void addClass(List<Class<?>> classSet, String packagePath, String packageName) {
        File[] files = new File(packagePath).listFiles(new FileFilter() {
            @Override
            public boolean accept(File file) {
                return (file.isFile() && file.getName().endsWith(".class")) || file.isDirectory();
            }
        });
        if (files == null) {
            // listFiles returns null (rather than throwing) when the path is not a readable directory.
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isFile()) {
                String className = fileName.substring(0, fileName.lastIndexOf("."));
                if (hasText(packageName)) {
                    className = packageName + "." + className;
                }
                doAddClass(classSet, className);
            } else {
                String subPackagePath = hasText(packagePath) ? packagePath + "/" + fileName : fileName;
                String subPackageName = hasText(packageName) ? packageName + "." + fileName : fileName;
                addClass(classSet, subPackagePath, subPackageName);
            }
        }
    }

    /** Loads {@code className} without initializing it and appends it to {@code classSet}. */
    private static void doAddClass(List<Class<?>> classSet, String className) {
        Class<?> cls = loadClass(className, false);
        classSet.add(cls);
    }

    /** True when {@code str} is non-null and not blank after {@code trim()}; JDK-only check. */
    private static boolean hasText(String str) {
        return str != null && !str.trim().isEmpty();
    }
}
StringUtil
import org.apache.commons.lang3.StringUtils;
/**
 * String helpers built on commons-lang3 {@link StringUtils}.
 */
public final class StringUtil {

    /**
     * Separator consisting of the single character with code 29 (ASCII "group separator").
     */
    public static final String SEPARATOR = String.valueOf((char) 29);

    /**
     * Returns {@code true} when {@code str} is {@code null} or empty after {@link String#trim()}.
     */
    public static boolean isEmpty(String str) {
        return str == null || StringUtils.isEmpty(str.trim());
    }

    /**
     * Negation of {@link #isEmpty(String)}.
     */
    public static boolean isNotEmpty(String str) {
        return !isEmpty(str);
    }

    /**
     * Splits {@code str} on the whole {@code separator} string, delegating to
     * {@link StringUtils#splitByWholeSeparator(String, String)}.
     */
    public static String[] splitString(String str, String separator) {
        return StringUtils.splitByWholeSeparator(str, separator);
    }
}