【mybatis专题】mybatis原理之动态代理生成的类注入spring容器

阅读过上一章的同学可以发现：MyBatis 只需要配置 @MapperScan 注解，就可以将扫描到的包路径下的 mapper 接口实例化，其底层使用的是 JDK 的动态代理。

下面我们模仿 @MapperScan，编写一个简单的 @MyMapperScan 注解，实现类似的功能。

首先我们模仿一个简单的 SqlSession：该 session 只定义了一个方法，用于根据传入的 Mapper 接口类型获取 mapper 实例。

/**
 * Minimal stand-in for MyBatis's {@code SqlSession}: the only operation it offers is
 * obtaining a mapper instance for a given mapper interface.
 */
public interface MySession {

    /**
     * Obtains a mapper instance for the given interface.
     *
     * @param mapperInterface the mapper interface type to obtain an instance for
     * @return the mapper instance produced by the implementation
     */
    Object getMapper(Class mapperInterface);
}

然后我们使用 JDK 动态代理来实现 MySession 接口：

package org.example.my;

import java.lang.reflect.Proxy;

/**
 * Default {@link MySession} implementation that hands out JDK dynamic proxies.
 * Every call to {@link #getMapper(Class)} builds a new proxy backed by a new
 * {@link MyInvocationHandler}.
 */
public class MySessionDefault implements MySession {

    /**
     * Creates a JDK dynamic proxy that implements exactly {@code mapperInterface}.
     *
     * @param mapperInterface the mapper interface the proxy should implement
     * @return a new proxy instance, defined in the interface's own class loader
     */
    @Override
    public Object getMapper(Class mapperInterface) {
        ClassLoader loader = mapperInterface.getClassLoader();
        Class<?>[] proxiedInterfaces = {mapperInterface};
        MyInvocationHandler handler = new MyInvocationHandler();
        return Proxy.newProxyInstance(loader, proxiedInterfaces, handler);
    }
}

Proxy.newProxyInstance 的第三个参数需要传入一个实现了 InvocationHandler 接口的实例，我们模仿实现了一个简单的 MyInvocationHandler：

package org.example.my;

import com.alibaba.fastjson.JSON;
import org.apache.ibatis.annotations.Select;
import org.example.entity.Blog;
import org.example.util.JdbcUtil;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.List;

/**
 * Invocation handler behind the mapper proxies created by {@link MySessionDefault}.
 *
 * <p>Each mapper method is expected to carry a MyBatis {@link Select} annotation. The first
 * SQL string of that annotation is run through {@link JdbcUtil#executeQuery} and the
 * resulting list is returned. Binding the method arguments into the SQL placeholders is
 * intentionally left out of this demo.
 */
public class MyInvocationHandler implements InvocationHandler {

    /**
     * Dispatches a call made on the mapper proxy.
     *
     * @param proxy  the proxy instance the method was invoked on
     * @param method the invoked method
     * @param args   the call arguments (currently not bound into the SQL)
     * @return for mapper methods, the list of rows mapped by {@link JdbcUtil#executeQuery};
     *         for {@code Object} methods, an identity-based result
     * @throws UnsupportedOperationException if a mapper method has no {@code @Select}
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // Proxy also routes hashCode/equals/toString here (with Object as the declaring
        // class). They carry no @Select, so answer them locally instead of hitting an NPE.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        Select select = method.getDeclaredAnnotation(Select.class);
        if (select == null) {
            throw new UnsupportedOperationException(
                    "Mapper method " + method + " has no @Select annotation");
        }
        System.out.println("conn db");
        System.out.println("execute sql");
        // Binding the interface arguments into the SQL placeholders is omitted here.
        String sql = select.value()[0];
        List<Blog> blogList = JdbcUtil.executeQuery(sql, null);
        return blogList;
    }

    /** Identity-based hashCode/equals/toString for the proxy instance itself. */
    private static Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return proxy.getClass().getName() + "@"
                        + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException("Unexpected Object method: " + method);
        }
    }

}

上面的 MyInvocationHandler 只是简单地读取方法上 @Select 注解中的 SQL，并执行数据库查询。下面是 JDBC 的简单工具类：

package org.example.util;

import org.example.entity.Blog;

import java.io.IOException;
import java.io.InputStream;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/**
 * Minimal JDBC helper used by the demo mapper proxy.
 *
 * <p>Each operation follows the classic JDBC steps:
 * 1. load the driver, 2. open a connection, 3. prepare the SQL as a PreparedStatement,
 * 4. execute the query/update, 5. process the result, 6. release the resources.
 *
 * <p>Connection settings are read once, when the class is initialized, from
 * {@code db.properties} on the classpath (keys {@code datasource.driverClass},
 * {@code datasource.url}, {@code datasource.username}, {@code datasource.password}).
 */
public class JdbcUtil {

    private static String driverClass = null;
    private static String url = null;
    private static String username = null;
    private static String password = null;

    static {
        // Read the connection settings from db.properties on the classpath.
        Properties properties = new Properties();
        // try-with-resources guarantees the classpath stream is closed after loading.
        try (InputStream inputStream =
                     JdbcUtil.class.getClassLoader().getResourceAsStream("db.properties")) {
            if (inputStream == null) {
                // Fail with a clear message rather than a bare NPE from Properties.load(null).
                throw new IllegalStateException("db.properties not found on the classpath");
            }
            properties.load(inputStream);
            driverClass = properties.getProperty("datasource.driverClass");
            url = properties.getProperty("datasource.url");
            username = properties.getProperty("datasource.username");
            password = properties.getProperty("datasource.password");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads the configured driver class and opens a new connection with the configured
     * credentials. The caller is responsible for closing the returned connection.
     *
     * @return a new JDBC connection
     * @throws ClassNotFoundException if the configured driver class cannot be loaded
     * @throws SQLException           if the connection cannot be opened
     */
    public static Connection getConnection() throws ClassNotFoundException, SQLException {
        Class.forName(driverClass);
        Connection connection = DriverManager.getConnection(url, username, password);
        return connection;
    }

    /**
     * Closes the given resources in the order result set, statement, connection.
     * {@code null} arguments are skipped; close failures are printed, not propagated.
     */
    public static void close(Connection connection, PreparedStatement preparedStatement, ResultSet resultSet) {
        if (resultSet != null) {
            try {
                resultSet.close();
            } catch (SQLException throwables) {
                throwables.printStackTrace();
            }
        }
        if (preparedStatement != null) {
            try {
                preparedStatement.close();
            } catch (SQLException throwables) {
                throwables.printStackTrace();
            }
        }
        if (connection != null) {
            try {
                connection.close();
            } catch (SQLException throwables) {
                throwables.printStackTrace();
            }
        }
    }

    /**
     * Runs {@code sql} as a query and maps every row to a {@link Blog}.
     *
     * @param sql    the SQL to execute
     * @param params values bound to the {@code ?} placeholders in order; may be {@code null}
     * @return the mapped rows (empty if the query returns none)
     * @throws SQLException           on any database error
     * @throws ClassNotFoundException if the JDBC driver cannot be loaded
     */
    public static List<Blog> executeQuery(String sql, Object[] params) throws SQLException, ClassNotFoundException {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            preparedStatement = connection.prepareStatement(sql);
            bindParams(preparedStatement, params);
            resultSet = preparedStatement.executeQuery();
            List<Blog> blogList = new ArrayList<>();
            while (resultSet.next()) {
                blogList.add(mapBlog(resultSet));
            }
            return blogList;
        } finally {
            close(connection, preparedStatement, resultSet);
        }
    }

    /**
     * Executes {@code sql} (typically an insert/update/delete).
     *
     * @param sql    the SQL to execute
     * @param params values bound to the {@code ?} placeholders in order; may be {@code null}
     * @return the result of {@link PreparedStatement#execute()}: {@code true} if the first
     *         result is a result set
     * @throws SQLException           on any database error
     * @throws ClassNotFoundException if the JDBC driver cannot be loaded
     */
    public static boolean execute(String sql, Object[] params) throws SQLException, ClassNotFoundException {
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        try {
            connection = getConnection();
            preparedStatement = connection.prepareStatement(sql);
            bindParams(preparedStatement, params);
            boolean res = preparedStatement.execute();
            return res;
        } finally {
            close(connection, preparedStatement, null);
        }
    }

    /** Binds {@code params} to the statement's placeholders; JDBC parameter indices start at 1. */
    private static void bindParams(PreparedStatement preparedStatement, Object[] params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            preparedStatement.setObject(i + 1, params[i]);
        }
    }

    /** Maps the current row of {@code resultSet} to a new {@link Blog}. */
    private static Blog mapBlog(ResultSet resultSet) throws SQLException {
        Blog blog = new Blog();
        blog.setBlogId(resultSet.getInt("blog_id"));
        blog.setTitle(resultSet.getString("title"));
        blog.setContent(resultSet.getString("content"));
        blog.setAuthor(resultSet.getString("author"));
        // getTimestamp always yields java.sql.Timestamp (a java.util.Date subclass), whereas
        // getObject may return other types (e.g. java.time values) depending on the driver.
        blog.setCreateTime(resultSet.getTimestamp("create_time"));
        blog.setUpdateTime(resultSet.getTimestamp("update_time"));
        return blog;
    }
}

下面是数据库连接配置文件 db.properties：

datasource.driverClass = com.mysql.jdbc.Driver
datasource.url = jdbc:mysql://localhost:3306/my_test?useSSL=false
datasource.username = root
datasource.password = root

然后我们实现一个 FactoryBean：MyFactoryBean 会根据传入的 mapperInterface 参数生成对应的 mapper 代理实例。FactoryBean 可以让我们控制实例的创建过程。

package org.example.my;

import org.springframework.beans.factory.FactoryBean;

/**
 * {@link FactoryBean} that produces a JDK-proxy implementation of a mapper interface by
 * delegating to {@link MySessionDefault#getMapper(Class)}.
 *
 * @param <T> the mapper interface type
 */
public class MyFactoryBean<T> implements FactoryBean<T> {

    /** The mapper interface to proxy. */
    private Class<T> mapperInterface;

    /**
     * Sets the mapper interface to proxy. A setter is required because
     * MyImportBeanDefinitionRegistrar supplies this value as a bean property.
     *
     * @param mapperInterface the mapper interface type
     */
    public void setMapperInterface(Class<T> mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    /**
     * Creates a proxy implementing {@code mapperInterface}.
     *
     * @return the mapper proxy
     * @throws IllegalStateException if {@code mapperInterface} has not been set
     */
    @Override
    public T getObject() throws Exception {
        if (mapperInterface == null) {
            throw new IllegalStateException("mapperInterface must be set before getObject() is called");
        }
        // Class.cast is a checked conversion; the proxy implements mapperInterface, so it succeeds.
        return mapperInterface.cast(new MySessionDefault().getMapper(mapperInterface));
    }

    /** Returns the mapper interface type, or {@code null} if it has not been set yet. */
    @Override
    public Class<?> getObjectType() {
        return mapperInterface;
    }

    /** Declares the produced mapper proxy as a singleton. */
    @Override
    public boolean isSingleton() {
        return true;
    }
}

下面是比较关键的部分：通过 ImportBeanDefinitionRegistrar 为 MyFactoryBean 注册 BeanDefinition，从而将它产生的 mapper 代理对象注入到 Spring 容器中。

package org.example.my;

import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.type.AnnotationMetadata;

/**
 * Registers the demo's mapper bean with Spring. This class is imported by {@code @MyMapperScan}.
 *
 * <p>It registers a single bean named {@code "blogMapper"} whose class is {@link MyFactoryBean}.
 * The mapper interface is supplied as the class-name string
 * {@code "org.example.mapper.BlogMapper"} for the {@code mapperInterface} property. This relies
 * on Spring's type conversion to turn the string into a {@code Class}.
 */
public class MyImportBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar {

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        AbstractBeanDefinition mapperDefinition = BeanDefinitionBuilder
                .genericBeanDefinition(MyFactoryBean.class)
                .addPropertyValue("mapperInterface", "org.example.mapper.BlogMapper")
                .getBeanDefinition();
        registry.registerBeanDefinition("blogMapper", mapperDefinition);
    }

}

然后我们模仿 @MapperScan 写一个自定义注解 @MyMapperScan，并通过 @Import 引入上面的 MyImportBeanDefinitionRegistrar：

/**
 * Hand-rolled stand-in for MyBatis's {@code @MapperScan}.
 *
 * <p>Annotating a configuration class with it imports {@link MyImportBeanDefinitionRegistrar},
 * which registers the mapper bean definition. The annotation is retained at runtime.
 */
@Import(value = MyImportBeanDefinitionRegistrar.class)
@Retention(value = RetentionPolicy.RUNTIME)
public @interface MyMapperScan {
}

下面是简单的使用示例：

/**
 * Service facade for reading blogs.
 */
public interface BlogService {

    /**
     * Queries the blogs. In {@code BlogServiceImpl} this delegates to
     * {@code BlogMapper#queryAllBlog()}.
     *
     * @return the blogs returned by the implementation
     */
    List<Blog> queryAllBlog();
}
package org.example.service.impl;

import org.example.entity.Blog;
import org.example.mapper.BlogMapper;
import org.example.service.BlogService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

/**
 * {@link BlogService} backed by the {@link BlogMapper} bean from the Spring container.
 */
@Service
public class BlogServiceImpl implements BlogService {

    /** Mapper bean injected by Spring (in this demo, the proxy registered as "blogMapper"). */
    @Autowired
    private BlogMapper blogMapper;

    /**
     * Loads the blogs through the mapper and echoes them to standard output.
     *
     * @return the blogs returned by {@code blogMapper.queryAllBlog()}
     */
    @Override
    public List<Blog> queryAllBlog() {
        final List<Blog> result = blogMapper.queryAllBlog();
        System.out.println(result);
        return result;
    }
}
package org.example;

import org.example.entity.Blog;
import org.example.my.MyFactoryBean;
import org.example.my.MyMapperScan;
import org.example.service.BlogService;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.context.annotation.*;

import java.util.List;

/**
 * Demo entry point. It component-scans {@code org.example} and uses {@link MyMapperScan}
 * (instead of MyBatis's {@code @MapperScan}) to register the mapper proxy bean.
 */
@ComponentScan(basePackages = {"org.example"})
// Switch to the real MyBatis scanner by enabling the line below instead of @MyMapperScan:
//@MapperScan(basePackages = {"org.example.mapper"})
@MyMapperScan
public class App {

    /**
     * Boots the Spring context, prints the runtime class of the "blogMapper" bean, and runs
     * the blog query through {@link BlogService}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // try-with-resources closes the context (destroying its singletons) when the demo ends.
        try (AnnotationConfigApplicationContext applicationContext =
                     new AnnotationConfigApplicationContext(App.class)) {
            System.out.println("----BlogMapper-----:" + applicationContext.getBean("blogMapper").getClass());
            BlogService blogService = applicationContext.getBean(BlogService.class);
            System.out.println("blogService testBlogMapper sql:" + blogService.queryAllBlog());
        }
    }
}

下面是测试结果，可以发现 mapper 已经被成功 Autowire（自动注入）到 BlogService 的实例中，并且能够查询出数据库中的数据：

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值