自定义类加载器

文章介绍了如何在Java中创建一个名为JarLoader的自定义类加载器,它打破了传统的双亲委派机制,采用逆向委派策略,同时提供了类缓存和ClassLoaderSwapper工具类,以解决jar包冲突问题。
摘要由CSDN通过智能技术生成

Java 中自定义类加载器,并将双亲委派改为逆向双亲委派(即"子优先":先由自定义加载器在自己的 jar 包中查找类,找不到时再委派给父加载器)

自定义类加载器JarLoader:

package cn.ac.iscas.dmo.common.tools.core.classloader;

import org.apache.commons.collections4.MapUtils;

import java.io.*;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.security.ProtectionDomain;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Class loader that isolates a group of jars. Every given path, all of its sub-directories, and
 * every {@code *.jar} file found in those directories are added to the class path; paths that are
 * regular files are added directly.
 *
 * <p>The parent-delegation model is inverted ("child-first"): {@link #loadClass(String)} first
 * tries to define the class from this loader's own jars and only delegates to the parent when the
 * class is not found there.
 *
 * <p>Caching:
 * <ul>
 *   <li>classes defined through {@link #loadClass(String)} are remembered per instance and, when a
 *       database type was supplied, in the shared {@link #typeJarLoaderClasses} map keyed by that
 *       type, so loaders created for the same type return the same {@code Class} objects;</li>
 *   <li>when {@code useCache} is {@code true}, the raw class bytes and their code source are also
 *       cached per path list, so a new loader over the same paths can define classes without
 *       re-reading the jars. {@link #clearCache(String[])} drops that cache.</li>
 * </ul>
 *
 * @author admin
 */
@SuppressWarnings({"rawtypes", "unused", "unchecked"})
public class JarLoader extends URLClassLoader {
    /** Separator used to join the configured paths into the byte-cache key. */
    private static final String PATH_SEPARATOR = ";";

    /** Jar URLs scanned by {@link #findClass(String)}, in class-path order. */
    private final URL[] allUrl;
    /** Whether class bytes are read from / written to the static byte cache. */
    private final boolean useCache;
    private final String[] paths;
    /** Paths joined with {@link #PATH_SEPARATOR}; key of the static byte cache. */
    private final String pathStr;
    /** Optional database type; {@code null} disables the shared per-type class cache. */
    private String dbType;

    /** Classes defined by this loader through {@link #loadClass(String)}, by binary name. */
    private final Map<String, Class<?>> jarLoaderClasses = new ConcurrentHashMap<>();

    /** Classes defined by loaders of a given database type: type -> (class name -> class). */
    public static Map<String, Map<String, Class>> typeJarLoaderClasses = new ConcurrentHashMap<>();

    /** Cached class bytes: joined path list -> (class name -> bytes). */
    private static final Map<String, Map<String, byte[]>> CLASS_BYTES = new ConcurrentHashMap<>();

    /**
     * Code source of each cached class: joined path list -> (class name -> code source). Keyed the
     * same way as {@link #CLASS_BYTES} so loaders over different jars never mix entries.
     */
    private static final Map<String, Map<String, CodeSource>> CODE_SOURCES = new ConcurrentHashMap<>();

    /**
     * Creates a loader whose parent is the loader of {@code JarLoader} and whose defined classes
     * are shared through {@link #typeJarLoaderClasses} under {@code type}.
     *
     * @param paths    jar files and/or directories to put on the class path
     * @param useCache whether to cache class bytes per path list
     * @param type     database type used as the shared-cache key (may be {@code null})
     */
    public JarLoader(String[] paths, boolean useCache, String type) {
        this(paths, JarLoader.class.getClassLoader(), useCache);
        this.dbType = type;
    }

    /**
     * Creates a loader with an explicit parent and no database type.
     *
     * @param paths    jar files and/or directories to put on the class path
     * @param parent   parent loader consulted when a class is not in {@code paths}
     * @param useCache whether to cache class bytes per path list
     * @throws RuntimeException if {@code paths} is null/empty or a directory cannot be read
     */
    public JarLoader(String[] paths, ClassLoader parent, boolean useCache) {
        this(getUrls(paths), paths, parent, useCache);
    }

    public JarLoader(String[] paths) {
        this(paths, JarLoader.class.getClassLoader(), false);
    }

    public JarLoader(String[] paths, ClassLoader parent) {
        this(paths, parent, false);
    }

    /** Shared initialisation; receives the URLs computed by {@link #getUrls(String[])}. */
    private JarLoader(URL[] urls, String[] paths, ClassLoader parent, boolean useCache) {
        super(urls, parent);
        this.allUrl = urls;
        this.useCache = useCache;
        this.paths = paths;
        this.pathStr = String.join(PATH_SEPARATOR, paths);
    }

    /**
     * Drops the cached class bytes (and code sources) for a path list, e.g. after the jars under
     * those paths were replaced without restarting the service.
     *
     * @param paths the same path array that was passed to the loader's constructor
     */
    public static void clearCache(String[] paths) {
        String key = String.join(PATH_SEPARATOR, paths);
        CLASS_BYTES.remove(key);
        CODE_SOURCES.remove(key);
    }

    /**
     * Convenience method: loads {@code name} through the current thread's context class loader.
     */
    public static Class<?> outerLoadClass(String name) throws ClassNotFoundException {
        return Thread.currentThread().getContextClassLoader().loadClass(name);
    }

    /**
     * Expands {@code paths} into class-path URLs: jars found in every directory path and its
     * sub-directories first, then the paths that are regular files.
     */
    private static URL[] getUrls(String[] paths) {
        if (null == paths || 0 == paths.length) {
            throw new RuntimeException("jar包路径不能为空.");
        }

        List<File> jarFiles = new ArrayList<>();
        List<String> dirs = new ArrayList<>();
        for (String path : paths) {
            File file = new File(path);
            if (file.isFile()) {
                jarFiles.add(file);
            } else {
                dirs.add(path);
                collectDirs(path, dirs);
            }
        }

        List<URL> urls = new ArrayList<>();
        for (String dir : dirs) {
            urls.addAll(doGetUrls(dir));
        }
        for (File jarFile : jarFiles) {
            try {
                urls.add(jarFile.toURI().toURL());
            } catch (Exception e) {
                throw new RuntimeException("系统加载jar包出错", e);
            }
        }
        return urls.toArray(new URL[0]);
    }

    /** Recursively appends the absolute paths of all sub-directories of {@code path}. */
    private static void collectDirs(String path, List<String> collector) {
        if (null == path || path.isEmpty()) {
            return;
        }
        File current = new File(path);
        if (!current.isDirectory()) {
            return;
        }
        File[] children = current.listFiles();
        if (children == null) {
            // listFiles() returns null on an I/O error; there is nothing we can descend into.
            return;
        }
        for (File child : children) {
            if (child.isDirectory()) {
                collector.add(child.getAbsolutePath());
                collectDirs(child.getAbsolutePath(), collector);
            }
        }
    }

    /** Returns the URLs of the {@code *.jar} files directly inside directory {@code path}. */
    private static List<URL> doGetUrls(final String path) {
        if (null == path || path.isEmpty()) {
            throw new RuntimeException("jar包路径不能为空.");
        }
        File jarPath = new File(path);
        if (!jarPath.exists() || !jarPath.isDirectory()) {
            throw new RuntimeException("jar包路径必须存在且为目录.");
        }

        FileFilter jarFilter = pathname -> pathname.getName().endsWith(".jar");
        File[] allJars = jarPath.listFiles(jarFilter);
        if (allJars == null) {
            throw new RuntimeException("无法读取jar包目录: " + path);
        }

        List<URL> jarUrls = new ArrayList<>(allJars.length);
        for (File jar : allJars) {
            try {
                jarUrls.add(jar.toURI().toURL());
            } catch (Exception e) {
                throw new RuntimeException("系统加载jar包出错", e);
            }
        }
        return jarUrls;
    }

    /**
     * Loads a class child-first: cached classes, then this loader's jars, then the parent.
     *
     * @param name binary name of the class
     * @return the loaded class, never {@code null}
     * @throws ClassNotFoundException if neither this loader's jars nor the parent chain have it
     */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        synchronized (getClassLoadingLock(name)) {
            Class<?> aClass = findCachedClass(name);
            if (aClass != null) {
                return aClass;
            }

            aClass = findClass(name);
            if (aClass != null) {
                jarLoaderClasses.put(name, aClass);
                if (dbType != null) {
                    // ConcurrentHashMap rejects null keys, so the shared cache needs a type.
                    typeJarLoaderClasses.computeIfAbsent(dbType, key -> new ConcurrentHashMap<>(32))
                            .put(name, aClass);
                }
                return aClass;
            }

            // Not in our jars: regular delegation. ClassLoader.loadClass ends with findClass,
            // which here returns null instead of throwing, so translate a miss explicitly.
            aClass = super.loadClass(name);
            if (aClass == null) {
                throw new ClassNotFoundException(name);
            }
            return aClass;
        }
    }

    /** Looks {@code name} up in the per-type cache, the instance cache and the JVM's record. */
    private Class<?> findCachedClass(String name) {
        if (dbType != null) {
            Map<String, Class> byType = typeJarLoaderClasses.get(dbType);
            if (byType != null) {
                Class<?> shared = byType.get(name);
                if (shared != null) {
                    return shared;
                }
            }
        }
        Class<?> own = jarLoaderClasses.get(name);
        return own != null ? own : findLoadedClass(name);
    }

    /**
     * Defines {@code name} from this loader's jars (or from the byte cache when enabled).
     *
     * <p>Unlike {@link ClassLoader#findClass(String)}, this returns {@code null} instead of
     * throwing when no jar contains the class; {@link #loadClass(String)} depends on that.
     *
     * @param name binary name of the class
     * @return the class defined by this loader, or {@code null} if it is not on the path
     */
    @Override
    public Class<?> findClass(String name) {
        synchronized (getClassLoadingLock(name)) {
            // Defining the same name twice in one loader throws LinkageError.
            Class<?> loaded = findLoadedClass(name);
            if (loaded != null && loaded.getClassLoader() == this) {
                return loaded;
            }
            if (useCache) {
                Class<?> cached = defineFromByteCache(name);
                if (cached != null) {
                    System.out.println("读取到缓存.....");
                    return cached;
                }
            }
            return defineFromJars(name);
        }
    }

    /** Defines {@code name} from cached bytes for this path list, or returns null if absent. */
    private Class<?> defineFromByteCache(String name) {
        Map<String, byte[]> bytesByName = CLASS_BYTES.get(pathStr);
        byte[] bytes = bytesByName == null ? null : bytesByName.get(name);
        if (bytes == null) {
            return null;
        }
        Map<String, CodeSource> sources = CODE_SOURCES.get(pathStr);
        CodeSource source = sources == null ? null : sources.get(name);
        // Build a fresh domain bound to *this* loader rather than reusing another loader's.
        ProtectionDomain domain = source == null ? null : new ProtectionDomain(source, null, this, null);
        return defineClass(name, bytes, 0, bytes.length, domain);
    }

    /** Defines {@code name} from the first jar that contains it, or returns null if none does. */
    private Class<?> defineFromJars(String name) {
        if (allUrl == null) {
            return null;
        }
        String entryName = name.replace('.', '/').concat(".class");

        for (URL url : allUrl) {
            try {
                File file = new File(url.toURI());
                if (!file.exists()) {
                    continue;
                }
                byte[] data;
                try (JarFile jarFile = new JarFile(file)) {
                    JarEntry jarEntry = jarFile.getJarEntry(entryName);
                    if (jarEntry == null) {
                        continue;
                    }
                    data = readFully(jarFile, jarEntry);
                }

                CodeSource codeSource = new CodeSource(url, (Certificate[]) null);
                ProtectionDomain protectionDomain = new ProtectionDomain(codeSource, null, this, null);
                Class<?> defined = defineClass(name, data, 0, data.length, protectionDomain);

                if (useCache) {
                    System.out.println("写入缓存---");
                    CLASS_BYTES.computeIfAbsent(pathStr, k -> new ConcurrentHashMap<>(2)).put(name, data);
                    CODE_SOURCES.computeIfAbsent(pathStr, k -> new ConcurrentHashMap<>(2)).put(name, codeSource);
                }
                // First match wins: defining the class again from a later jar would throw
                // LinkageError (duplicate class definition).
                return defined;
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        return null;
    }

    /** Reads the full content of {@code entry}; both streams are closed on return. */
    private static byte[] readFully(JarFile jarFile, JarEntry entry) throws IOException {
        try (InputStream in = jarFile.getInputStream(entry);
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[4096];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            return out.toByteArray();
        }
    }

}

类加载器切换工具类:

package cn.ac.iscas.dmo.common.tools.core.classloader;

/**
 * Temporarily replaces the current thread's context class loader and later puts the previous one
 * back.
 *
 * <p>Intended for avoiding jar conflicts (for example when several HBase client versions are
 * needed): switch the context class loader to an isolated loader, run the code that needs those
 * jars, then restore the original loader and continue.
 *
 * <p>An instance remembers only the loader replaced by the most recent
 * {@link #setCurrentThreadClassLoader(ClassLoader)} call, in a plain unsynchronized field.
 *
 * @author admin
 */
public final class ClassLoaderSwapper {
    /** Context class loader captured by the last setCurrentThreadClassLoader call; null until then. */
    private ClassLoader previousLoader;

    private ClassLoaderSwapper() {
        // instances are obtained through newCurrentThreadClassLoaderSwapper()
    }

    /**
     * Creates a swapper that has not saved any loader yet.
     *
     * @return a new swapper
     */
    public static ClassLoaderSwapper newCurrentThreadClassLoaderSwapper() {
        return new ClassLoaderSwapper();
    }

    /**
     * Remembers the current thread's context class loader, then installs {@code classLoader}.
     *
     * @param classLoader loader to install as the context class loader
     * @return the loader that was installed before this call (now remembered)
     */
    @SuppressWarnings("UnusedReturnValue")
    public ClassLoader setCurrentThreadClassLoader(ClassLoader classLoader) {
        final Thread current = Thread.currentThread();
        previousLoader = current.getContextClassLoader();
        current.setContextClassLoader(classLoader);
        return previousLoader;
    }

    /**
     * Re-installs the remembered loader as the current thread's context class loader. If
     * setCurrentThreadClassLoader was never called, the context class loader is set to null.
     *
     * @return the context class loader that was active before restoring
     */
    @SuppressWarnings("UnusedReturnValue")
    public ClassLoader restoreCurrentThreadClassLoader() {
        final Thread current = Thread.currentThread();
        final ClassLoader replaced = current.getContextClassLoader();
        current.setContextClassLoader(previousLoader);
        return replaced;
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值