spring入门——如何实现包扫描

spring中的包扫描

在spring中有两种方式可以实现包扫描

  1. 传统的xml配置方式

    <!--配置扫描com.example.spring.beans下的所有bean-->
    <context:component-scan base-package="com.example.spring.beans"/>
    
  2. 基于注解的方式

    @Configuration
    @ComponentScan("com.example.spring.beans4")
    public class ComponentScanConfig {
    }
    

如何实现呢?

今天换个思路,如果这个需求要我们来实现,那我们如何做呢?我们都知道,一个bean如果想被spring管理起来,那么一定得把BeanDefinition交给BeanDefinitionRegistry。那么,BeanDefinition从哪儿来呢?BeanDefinition应该从我们的扫描结果中来。那么应该如何扫描呢?我觉得应该分为以下2步:

  1. 使用classLoader.getResources(resourceName);查找到所有的resource。
  2. 遍历resource,根据protocol的不同进行不同的查找。

自己实现

根据上面的想法,我自己实现了一个简单的ClassScanner。实现需求的同时,基于“开闭原则”对接口进行封装。

源码

/**
 * 一个简单的class查找工具,目前仅支持jar包查找和本地查找
 */
public class SimpleClassScan {

    private final Set<Class<?>> classSet;
    private final Map<String, ProtocolHandler> handlerMap;

    public SimpleClassScan() {
        classSet = new HashSet<>();
        handlerMap = new HashMap<>();
        //注册一个文件扫描器
        FileProtocolHandler fileProtocolHandler = new FileProtocolHandler();
        //注册一个jar包扫描器
        JarProtocolHandler jarProtocolHandler = new JarProtocolHandler();
        handlerMap.put(fileProtocolHandler.handleProtocol(), fileProtocolHandler);
        handlerMap.put(jarProtocolHandler.handleProtocol(), jarProtocolHandler);
    }

    public Set<Class<?>> scan(String... basePackages) {
        ClassLoader classLoader = this.getClass().getClassLoader();
        for (String basePackage : basePackages) {
            //将com.aa.bb 替换成 com/aa/bb
            String resourceName = basePackage.replace('.', '/') + "/";
            Enumeration<URL> resources = null;
            try {
                //通过classLoader获取所有的resources
                resources = classLoader.getResources(resourceName);
            } catch (IOException e) {
                e.printStackTrace();
            }
            if (resources == null) {
                continue;
            }
            while (resources.hasMoreElements()) {
                URL url = resources.nextElement();
                String protocol = url.getProtocol();
                //根据url中protocol类型查找适用的解析器
                ProtocolHandler protocolHandler = handlerMap.get(protocol);
                if (protocolHandler == null) {
                    throw new RuntimeException("need support protocol [" + protocol + "]");
                }
                protocolHandler.handle(basePackage, url);
            }
        }
        return classSet;
    }

    /**
     * 将class添加到结果中
     * @param classFullName 形如com.aa.bb.cc.Test.class的字符串
     */
    private void addResult(String classFullName) {
        Class<?> aClass = null;
        try {
            aClass = Class.forName(classFullName.substring(0, classFullName.length() - 6));
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        if (aClass != null) {
            classSet.add(aClass);
        }
    }

    /**
     * 检查一个文件名是否是class文件名
     * @param fileName 文件名
     * @return
     */
    private boolean checkIsNotClass(String fileName) {
        //只要class类型的文件
        boolean isClass = fileName.endsWith(".class");
        if (!isClass) {
            return true;
        }
        //排除内部类
        return fileName.indexOf('$') != -1;
    }

    public Set<Class<?>> getClassSet() {
        return classSet;
    }

    /**
     * 协议处理器
     */
    private interface ProtocolHandler {
        /**
         * 适配的协议
         *
         * @return
         */
        String handleProtocol();

        /**
         * 处理url,最后需要调用{@link #addResult(String)}将结果存储到result中
         *
         * @param url
         */
        void handle(String basePackage, URL url);
    }

    /**
     * jar包解析器
     */
    private class JarProtocolHandler implements ProtocolHandler {

        @Override
        public String handleProtocol() {
            return "jar";
        }

        @Override
        public void handle(String basePackage, URL url) {
            try {
                String resourceName = basePackage.replace('.', '/') + "/";
                JarURLConnection conn = (JarURLConnection) url.openConnection();
                JarFile jarFile = conn.getJarFile();
                Enumeration<JarEntry> entries = jarFile.entries();
                while (entries.hasMoreElements()) {
                    //遍历jar包中的所有项
                    JarEntry jarEntry = entries.nextElement();
                    String entryName = jarEntry.getName();
                    if (!entryName.startsWith(resourceName)) {
                        continue;
                    }
                    if (checkIsNotClass(entryName)) {
                        continue;
                    }
                    String classNameFullName = entryName.replace('/', '.');
                    addResult(classNameFullName);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 文件解析器
     */
    private class FileProtocolHandler implements ProtocolHandler {

        @Override
        public String handleProtocol() {
            return "file";
        }

        @Override
        public void handle(String basePackage, URL url) {
            File rootFile = new File(url.getFile());
            findClass(rootFile, File.separator + basePackage.replace('.', File.separatorChar) + File.separator);
        }

        /**
         * 递归的方式查找class文件
         * @param rootFile 当前文件
         * @param subFilePath 子路径
         */
        private void findClass(File rootFile, String subFilePath) {
            if (rootFile == null) {
                return;
            }
            //如果是文件夹
            if (rootFile.isDirectory()) {
                File[] files = rootFile.listFiles();
                if (files == null) {
                    return;
                }
                for (File file : files) {
                    findClass(file, subFilePath);
                }
            }
            String fileName = rootFile.getName();
            if (checkIsNotClass(fileName)) {
                return;
            }
            String path = rootFile.getPath();
            int i = path.indexOf(subFilePath);
            String subPath = path.substring(i + 1);
            String fullClassPath = subPath.replace(File.separatorChar, '.');
            addResult(fullClassPath);
        }
    }
}

验证

我们看下实际效果:

  1. 扫描本地包中的class
    在这里插入图片描述
  2. 扫描依赖包中的class
    在这里插入图片描述

反思与总结

  1. 通过classLoader.getResources("resourceName")可以获取到resources,其中resourceName必须是com/xxx/rrr的形式

  2. 文件分隔符在Windows和linux上是不同的,在Windows上是\,在linux上是/。我们可以通过File.separatorChar来获取当前操作系统中的文件分隔符

  3. url是通过protocol来区分的

  4. 扫描文件系统,可以使用File对象+递归的方式实现

  5. 扫描jar时,需要openConnection

    JarURLConnection conn = (JarURLConnection) url.openConnection();
    JarFile jarFile = conn.getJarFile();
    Enumeration<JarEntry> entries = jarFile.entries();
    while (entries.hasMoreElements()) {
        //遍历jar包中的所有项
        JarEntry jarEntry = entries.nextElement();
        String entryName = jarEntry.getName();
        //TODO xxx
    }
    

spring的实现

那么spring是如何实现的呢?

spring的源码

在spring中,spring使用ClassPathBeanDefinitionScanner来实现包扫描,这个东西用起来非常方便,如果我想将com.example.spring.beans3下所有带有 @SkylineComponent注解的类注册为bean,那么可以这么写:

ClassPathBeanDefinitionScanner scanner = new ClassPathBeanDefinitionScanner(registry, false);
scanner.addIncludeFilter(new AnnotationTypeFilter(SkylineComponent.class));
scanner.scan("com.example.spring.beans3");

当然,这段代码一定要写在实现了BeanDefinitionRegistryPostProcessor的bean中。
接下来我们看下ClassPathBeanDefinitionScanner的构造器

public ClassPathBeanDefinitionScanner(BeanDefinitionRegistry registry, boolean useDefaultFilters,Environment environment, @Nullable ResourceLoader resourceLoader) {

	//最最重要的,BeanDefinitionRegistry不能空
	Assert.notNull(registry, "BeanDefinitionRegistry must not be null");
	this.registry = registry;

	//是否使用默认的过滤器,如果使用默认的过滤器,那么仅扫描@Component注解
	if (useDefaultFilters) {
		registerDefaultFilters();
	}
	//设置环境参数
	setEnvironment(environment);
	//设置资源加载器
	setResourceLoader(resourceLoader);
}

org.springframework.context.annotation.ClassPathBeanDefinitionScanner#scan就是扫描方法的入口,实际的扫描逻辑是写在doScan方法中的,我们看下这个方法:

protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
	//首先,basePackages不能是空的
	Assert.notEmpty(basePackages, "At least one base package must be specified");
	Set<BeanDefinitionHolder> beanDefinitions = new LinkedHashSet<>();
	//遍历所有要扫描的包
	for (String basePackage : basePackages) {
		//获取到待选的BeanDefinition
		Set<BeanDefinition> candidates = findCandidateComponents(basePackage);
		//遍历待选的BeanDefinition
		for (BeanDefinition candidate : candidates) {
			//设置Scope
			ScopeMetadata scopeMetadata = this.scopeMetadataResolver.resolveScopeMetadata(candidate);
			candidate.setScope(scopeMetadata.getScopeName());
			//生成beanName
			String beanName = this.beanNameGenerator.generateBeanName(candidate, this.registry);
			if (candidate instanceof AbstractBeanDefinition) {
				//默认值处理
				postProcessBeanDefinition((AbstractBeanDefinition) candidate, beanName);
			}
			if (candidate instanceof AnnotatedBeanDefinition) {
				//@Lazy @Primary @DependsOn @Role @Description这些注解支持
				AnnotationConfigUtils.processCommonDefinitionAnnotations((AnnotatedBeanDefinition) candidate);
			}
			//bean冲突校验
			if (checkCandidate(beanName, candidate)) {
				BeanDefinitionHolder definitionHolder = new BeanDefinitionHolder(candidate, beanName);
				definitionHolder =
						AnnotationConfigUtils.applyScopedProxyMode(scopeMetadata, definitionHolder, this.registry);
				beanDefinitions.add(definitionHolder);
				//注册bean
				registerBeanDefinition(definitionHolder, this.registry);
			}
		}
	}
	return beanDefinitions;
}

关键在于findCandidateComponents是如何找到这些候选的BeanDefinition的呢?接下来就会走到scanCandidateComponents中,接下来我们debug看下:
1

从代码中可以看到,spring将传入的"com.example.spring.beans3"解析成了"classpath*:com/example/spring/beans3/**/*.class",并通过getResourcePatternResolver().getResources(packageSearchPath)来获取所有的resource,那么,我们看下getResources是如何处理classpath*:com/example/spring/beans3/**/*.class的。接下来,程序执行到findPathMatchingResources中,在findPathMatchingResources中通过getResource方法来返回Resource[]。
在这里插入图片描述
那么getResource里面是什么呢?在getResource中,最后会调用到doFindAllClassPathResources,如下图:
在这里插入图片描述
这段代码好眼熟…Enumeration<URL> resourceUrls = (cl != null ? cl.getResources(path) : ClassLoader.getSystemResources(path));跟我的实现方式是一样的,也是先把package转换为资源路径,然后通过classLoader.getResources的方式来获取resource。接下来遍历这些resources,不同类型的resource走不同的逻辑,就像下面这样。
在这里插入图片描述
当所有的resources都获取到了之后,就开始遍历所有的resource。如下图:
在这里插入图片描述
这里spring的手法就比较高端了,spring通过读取resource中的class文件的字节码,生成了一个叫MetadataReader的对象。这个MetadaReader并不是class对象,但是可以读取到class上所有的元数据信息。这是因为spring使用了ASM技术,用流的方式读取了class文件。然后就是创建ScannedGenericBeanDefinition并返回了。

反思与总结

spring中的包扫描虽然整体逻辑并不复杂,但是细节还是很多的。比如它处理了通配符**/*.class、处理了不同协议的url、在最终读取class信息时使用了ASM技术、还支持自定义的过滤器等。spring在能扩展的地方都给我们留出了扩展点,但是在使用起来却是很方便,这一点还是很厉害的。

  • 12
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值