阅读过上一章的读者可以发现，MyBatis 只需配置 @MapperScan 注解，就能将扫描包路径下的 mapper 接口实例化，其底层使用的是 JDK 动态代理。
下面我们仿照它编写一个简单的 @MyMapperScan 注解，实现类似的功能。
首先仿照 SqlSession 定义一个简单的 MySession 接口，它只声明了一个方法：根据传入的 Mapper 接口类型获取对应的 mapper 实例。
/**
 * Minimal stand-in for MyBatis's SqlSession: hands out an instance for a
 * given mapper interface.
 */
public interface MySession {

    /**
     * Returns an object that implements the given mapper interface.
     *
     * @param mapperInterface the mapper interface to obtain an instance for
     * @return an instance for {@code mapperInterface}
     */
    Object getMapper(Class mapperInterface);
}
然后使用 JDK 动态代理实现 MySession 接口：
package org.example.my;
import java.lang.reflect.Proxy;
/**
 * {@link MySession} implementation that backs a mapper interface with a JDK
 * dynamic proxy whose calls are handled by {@link MyInvocationHandler}.
 */
public class MySessionDefault implements MySession {

    /**
     * Creates a JDK dynamic proxy implementing {@code mapperInterface}.
     * A new proxy and a new handler are created on every call.
     *
     * @param mapperInterface the interface the proxy must implement
     * @return the proxy instance
     */
    @Override
    public Object getMapper(Class mapperInterface) {
        ClassLoader loader = mapperInterface.getClassLoader();
        Class[] proxiedInterfaces = {mapperInterface};
        MyInvocationHandler handler = new MyInvocationHandler();
        return Proxy.newProxyInstance(loader, proxiedInterfaces, handler);
    }
}
Proxy.newProxyInstance 的第三个参数需要传入一个实现了 InvocationHandler 接口的实例，这里我们编写一个简单的 MyInvocationHandler：
package org.example.my;
import com.alibaba.fastjson.JSON;
import org.apache.ibatis.annotations.Select;
import org.example.entity.Blog;
import org.example.util.JdbcUtil;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.List;
/**
 * Handles every call made on a mapper proxy (see MySessionDefault).
 * For mapper methods it reads the SQL from the method's {@link Select}
 * annotation and runs it through {@link JdbcUtil#executeQuery}.
 */
public class MyInvocationHandler implements InvocationHandler {

    /**
     * Dispatches a call made on the proxy.
     *
     * @param proxy  the proxy instance the method was invoked on
     * @param method the invoked method
     * @param args   the call arguments; not yet bound into the SQL
     * @return for mapper methods, the list produced by {@code JdbcUtil.executeQuery};
     *         for {@code equals}/{@code hashCode}/{@code toString}, an identity-based answer
     * @throws IllegalStateException if a mapper method carries no {@code @Select} annotation
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // JDK proxies also route Object's equals/hashCode/toString through the handler.
        // Those methods have no @Select, so answer them locally instead of failing with an NPE.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        Select select = method.getDeclaredAnnotation(Select.class);
        if (select == null) {
            throw new IllegalStateException("Mapper method " + method + " has no @Select annotation");
        }
        System.out.println("conn db");
        System.out.println("execute sql");
        // Binding the method arguments into the SQL placeholders is still omitted here.
        // @Select.value() is a String[]; join all fragments with a space instead of
        // silently dropping everything after the first one.
        String sql = String.join(" ", select.value());
        List<Blog> blogList = JdbcUtil.executeQuery(sql, null);
        return blogList;
    }

    /** Answers the java.lang.Object methods a JDK proxy forwards, using the proxy's identity. */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return proxy.getClass().getName() + "@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(method.toString());
        }
    }
}
上面的 MyInvocationHandler 只是简单地执行了方法上 @Select 注解中的查询 SQL。下面是 JDBC 的简单工具类：
package org.example.util;
import org.example.entity.Blog;
import java.io.IOException;
import java.io.InputStream;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
/**
 * Minimal JDBC helper used by the mapper proxy.
 *
 * <p>It wraps the usual JDBC workflow:
 * 1. load the driver; 2. open a connection; 3. precompile the SQL with a PreparedStatement;
 * 4. execute the query/update; 5. process the result; 6. release the resources.
 *
 * <p>Connection settings are read once, when the class is initialized, from
 * {@code db.properties} on the classpath (keys {@code datasource.driverClass},
 * {@code datasource.url}, {@code datasource.username}, {@code datasource.password}).
 */
public class JdbcUtil {

    private static final String DRIVER_CLASS;
    private static final String URL;
    private static final String USERNAME;
    private static final String PASSWORD;

    static {
        Properties properties = new Properties();
        // try-with-resources closes the stream; a null resource is allowed and simply not closed.
        try (InputStream inputStream = JdbcUtil.class.getClassLoader().getResourceAsStream("db.properties")) {
            if (inputStream == null) {
                throw new IllegalStateException("db.properties not found on the classpath");
            }
            properties.load(inputStream);
        } catch (IOException e) {
            // Without the config every later call would fail with an obscure NPE; fail fast instead.
            throw new ExceptionInInitializerError(e);
        }
        DRIVER_CLASS = properties.getProperty("datasource.driverClass");
        URL = properties.getProperty("datasource.url");
        USERNAME = properties.getProperty("datasource.username");
        PASSWORD = properties.getProperty("datasource.password");
    }

    /**
     * Loads the configured driver class and opens a new connection.
     *
     * @return a new connection; the caller must close it
     * @throws ClassNotFoundException if the configured driver class is not on the classpath
     * @throws SQLException           if the connection cannot be opened
     */
    public static Connection getConnection() throws ClassNotFoundException, SQLException {
        Class.forName(DRIVER_CLASS);
        return DriverManager.getConnection(URL, USERNAME, PASSWORD);
    }

    /**
     * Closes the given resources in reverse order of acquisition. Null arguments are skipped, and
     * a failure closing one resource is printed without preventing the others from being closed.
     */
    public static void close(Connection connection, PreparedStatement preparedStatement, ResultSet resultSet) {
        if (resultSet != null) {
            try {
                resultSet.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (preparedStatement != null) {
            try {
                preparedStatement.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (connection != null) {
            try {
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Runs a query and maps every row to a {@link Blog}.
     *
     * @param sql    the SQL with {@code ?} placeholders
     * @param params values bound to the placeholders in order; may be {@code null}
     * @return the mapped rows; empty if the query returns no rows
     */
    public static List<Blog> executeQuery(String sql, Object[] params) throws SQLException, ClassNotFoundException {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            preparedStatement = connection.prepareStatement(sql);
            bindParams(preparedStatement, params);
            resultSet = preparedStatement.executeQuery();
            List<Blog> blogList = new ArrayList<>();
            while (resultSet.next()) {
                blogList.add(toBlog(resultSet));
            }
            return blogList;
        } finally {
            close(connection, preparedStatement, resultSet);
        }
    }

    /**
     * Executes an arbitrary statement.
     *
     * @param sql    the SQL with {@code ?} placeholders
     * @param params values bound to the placeholders in order; may be {@code null}
     * @return the result of {@link PreparedStatement#execute()}: {@code true} if the first result
     *         is a ResultSet, {@code false} for an update count or no result
     */
    public static boolean execute(String sql, Object[] params) throws SQLException, ClassNotFoundException {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        try {
            connection = getConnection();
            preparedStatement = connection.prepareStatement(sql);
            bindParams(preparedStatement, params);
            return preparedStatement.execute();
        } finally {
            close(connection, preparedStatement, null);
        }
    }

    /** Binds {@code params} to the statement's placeholders; a null array binds nothing. */
    private static void bindParams(PreparedStatement preparedStatement, Object[] params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            // JDBC parameter indexes are 1-based; index 0 is rejected by the driver.
            preparedStatement.setObject(i + 1, params[i]);
        }
    }

    /** Maps the current row of {@code resultSet} to a new Blog. */
    private static Blog toBlog(ResultSet resultSet) throws SQLException {
        Blog blog = new Blog();
        blog.setBlogId(resultSet.getInt("blog_id"));
        // title was previously read but never copied into the entity.
        blog.setTitle(resultSet.getString("title"));
        blog.setContent(resultSet.getString("content"));
        blog.setAuthor(resultSet.getString("author"));
        // getTimestamp yields a java.sql.Timestamp (a java.util.Date) on every driver, whereas
        // casting getObject() breaks on drivers that return java.time.LocalDateTime.
        blog.setCreateTime(resultSet.getTimestamp("create_time"));
        blog.setUpdateTime(resultSet.getTimestamp("update_time"));
        return blog;
    }
}
下面是数据库连接配置文件 db.properties：
datasource.driverClass = com.mysql.jdbc.Driver
datasource.url = jdbc:mysql://localhost:3306/my_test?useSSL=false
datasource.username = root
datasource.password = root
然后实现一个 FactoryBean：MyFactoryBean 会根据传入的 mapperInterface 参数生成对应的 mapper 实例。FactoryBean 让我们可以自行控制实例的创建过程：
package org.example.my;
import org.springframework.beans.factory.FactoryBean;
/**
 * Spring {@link FactoryBean} that produces the proxy for one mapper interface
 * through {@link MySessionDefault}.
 */
public class MyFactoryBean implements FactoryBean {

    /**
     * Mapper interface to proxy. It is populated from the "mapperInterface" property
     * of the bean definition, which is why a setter is required.
     */
    private Class mapperInterface;

    /** Setter used by Spring when applying the bean definition's property values. */
    public void setMapperInterface(Class mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    /** Builds the mapper proxy from a fresh session. */
    @Override
    public Object getObject() throws Exception {
        MySession session = new MySessionDefault();
        return session.getMapper(mapperInterface);
    }

    /** Reports the mapper interface as the type of object this factory produces. */
    @Override
    public Class<?> getObjectType() {
        return mapperInterface;
    }

    /** Declares the produced mapper a singleton. */
    @Override
    public boolean isSingleton() {
        return true;
    }
}
下面是最关键的一步：通过 ImportBeanDefinitionRegistrar 向 Spring 容器注册 MyFactoryBean 的 BeanDefinition，使 MyFactoryBean 产生的 mapper 实例被注入到 Spring 容器中：
package org.example.my;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.type.AnnotationMetadata;
/**
 * Registers a {@link MyFactoryBean} definition for the BlogMapper interface under
 * the bean name "blogMapper". It is triggered by {@code @Import} on {@code @MyMapperScan}.
 */
public class MyImportBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar {

    private static final String MAPPER_BEAN_NAME = "blogMapper";

    /** Class name handed to MyFactoryBean#setMapperInterface; Spring converts it to a Class. */
    private static final String MAPPER_INTERFACE_NAME = "org.example.mapper.BlogMapper";

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        AbstractBeanDefinition factoryDefinition = BeanDefinitionBuilder
                .genericBeanDefinition(MyFactoryBean.class)
                .addPropertyValue("mapperInterface", MAPPER_INTERFACE_NAME)
                .getBeanDefinition();
        registry.registerBeanDefinition(MAPPER_BEAN_NAME, factoryDefinition);
    }
}
然后仿照 @MapperScan 编写一个自定义注解 @MyMapperScan，通过 @Import 引入上面的 MyImportBeanDefinitionRegistrar：
/**
 * Home-made counterpart of MyBatis's {@code @MapperScan}. Placing it on a configuration
 * class imports {@link MyImportBeanDefinitionRegistrar}, which registers the mapper beans.
 *
 * <p>It has no attributes: which mappers get registered is decided entirely by the registrar.
 * It is restricted to types because Spring only processes {@code @Import} on classes; on a
 * method or field it would be silently ignored.
 */
@Retention(RetentionPolicy.RUNTIME)
@java.lang.annotation.Target(java.lang.annotation.ElementType.TYPE)
@java.lang.annotation.Documented
@Import(MyImportBeanDefinitionRegistrar.class)
public @interface MyMapperScan {
}
下面是简单的使用示例：
/**
 * Service-layer API for blog posts.
 */
public interface BlogService {
    /**
     * Queries blogs through the mapper.
     * NOTE(review): presumably returns all blogs — the actual rows depend on the SQL in
     * BlogMapper's @Select annotation, which is not shown here.
     *
     * @return the queried blogs
     */
    List<Blog> queryAllBlog();
}
package org.example.service.impl;
import org.example.entity.Blog;
import org.example.mapper.BlogMapper;
import org.example.service.BlogService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * {@link BlogService} implementation that delegates to the {@link BlogMapper}
 * bean found in the Spring context.
 */
@Service
public class BlogServiceImpl implements BlogService {

    /** Mapper proxy injected by Spring by type. */
    @Autowired
    private BlogMapper blogMapper;

    /** Runs the mapper query, prints the result to stdout, and returns it unchanged. */
    @Override
    public List<Blog> queryAllBlog() {
        List<Blog> result = blogMapper.queryAllBlog();
        System.out.println(result);
        return result;
    }
}
package org.example;
import org.example.entity.Blog;
import org.example.my.MyFactoryBean;
import org.example.my.MyMapperScan;
import org.example.service.BlogService;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.context.annotation.*;
import java.util.List;
/**
 * Demo entry point. It boots a Spring context that scans org.example and registers the
 * mapper through the custom {@code @MyMapperScan}, then calls the mapper via BlogService.
 */
@ComponentScan(basePackages = {"org.example"})
// MyBatis's own annotation, kept for comparison; enable it instead of @MyMapperScan to compare.
//@MapperScan(basePackages = {"org.example.mapper"})
@MyMapperScan
public class App {
    public static void main(String[] args) {
        // Close the context so singleton destroy callbacks run; try-with-resources does this
        // even if a bean lookup or the query throws.
        try (AnnotationConfigApplicationContext applicationContext =
                     new AnnotationConfigApplicationContext(App.class)) {
            System.out.println("----BlogMapper-----:" + applicationContext.getBean("blogMapper").getClass());
            BlogService blogService = applicationContext.getBean(BlogService.class);
            System.out.println("blogService testBlogMapper sql:" + blogService.queryAllBlog());
        }
    }
}
下面是测试结果。可以看到，mapper 已被成功自动注入（@Autowired）到 BlogService 的实例中，并且能够查询出数据库中的数据：