package com.test;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StringUtils;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.List;
@SpringBootTest
public class EnumsTest{
private static final String PACKAGE_NAME = "com.test.enums";
private static final List<String> filterClazzList = new ArrayList<String>();
static {
filterClazzList.add("");
}
@Test
public void doTest() {
List<Class<?>> allClass = getClasses(PACKAGE_NAME);
for (Class classes : allClass) {// 循环反射执行所有类
try {
methodInvoke(classes);
} catch (Exception e) {
e.printStackTrace();
}
}
}
private void methodInvoke(Class clazz) {
Method[] declaredMethods = clazz.getDeclaredMethods();
Object[] enumConstants = clazz.getEnumConstants();
for (Method method : declaredMethods) {
String methodName = method.getName();
Class<?>[] parameterTypes = method.getParameterTypes();
if (StringUtils.substringMatch(methodName, 0,"set")) {//StringUtils.startWith(methodName,"set")
try {
method.invoke(enumConstants[0],adaptorGenObj(parameterTypes[0]));
} catch (IllegalAccessException e) {
e.printStackTrace();
} catch (InvocationTargetException e) {
e.printStackTrace();
} catch (InstantiationException e) {
e.printStackTrace();
}
}
if (StringUtils.substringMatch(methodName, 0,"get")) {
int count = method.getParameterCount();
try {
if (count == 0) {
method.invoke(enumConstants[0]);
} else if (count == 1) {
method.invoke(null,adaptorGenObj(parameterTypes[0]));
}else{
method.invoke(null,adaptorGenObj(parameterTypes[0]),adaptorGenObj(parameterTypes[1]));
}
} catch (IllegalAccessException e) {
e.printStackTrace();
} catch (InvocationTargetException e) {
e.printStackTrace();
} catch (InstantiationException e) {
e.printStackTrace();
}
}
}
}
private Object adaptorGenObj(Class<?> clazz)throws IllegalArgumentException, InvocationTargetException, InstantiationException, IllegalAccessException {
if (null == clazz) {
return null;
}
if ("int".equals(clazz.getName())) {
return 1;
} else if ("char".equals(clazz.getName())) {
return 'x';
} else if ("Integer".equals(clazz.getName())) {
return 1;
} else if ("boolean".equals(clazz.getName())) {
return true;
} else if ("double".equals(clazz.getName())) {
return 1.0;
} else if ("float".equals(clazz.getName())) {
return 1.0f;
} else if ("long".equals(clazz.getName())) {
return 1l;
} else if ("byte".equals(clazz.getName())) {
return 0xFFFFFFFF;
} else if ("java.lang.Class".equals(clazz.getName())) {
return this.getClass();
} else if ("java.math.BigDecimal".equals(clazz.getName())) {
return new BigDecimal(1);
} else if ("java.lang.String".equals(clazz.getName())) {
return "333";
} else if ("java.util.Hashtable".equals(clazz.getName())) {
return new Hashtable();
}else if ("java.util.List".equals(clazz.getName())) {
return new ArrayList();
}else {
return null;
}
}
/**
* 获取包下的所有类
* @param packageName
* @return
*/
private List<Class<?>> getClasses(String packageName) {
// 第一个class类的集合
List<Class<?>> classes = new ArrayList<Class<?>>();
// 是否循环迭代
boolean recursive = true;
// 获取包的名字 并进行替换
String packageDirName = packageName.replace('.', '/');
// 定义一个枚举的集合 并进行循环来处理这个目录下的things
Enumeration<URL> dirs;
try {
dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
// 循环迭代下去
while (dirs.hasMoreElements()) {
// 获取下一个元素
URL url = dirs.nextElement();
// 得到协议的名称
String protocol = url.getProtocol();
// 如果是以文件的形式保存在服务器上
if ("file".equals(protocol)) {
// 获取包的物理路径
String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
// 以文件的方式扫描整个包下的文件 并添加到集合中
findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
}
}
} catch (IOException e) {
e.printStackTrace();
}
return classes;
}
/**
* 查找包下的文件
* @param packageName
* @param packagePath
* @param recursive
* @param classes
*/
private void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive,
List<Class<?>> classes) {
// 获取此包的目录 建立一个File
File dir = new File(packagePath);
// 如果不存在或者 也不是目录就直接返回
if (!dir.exists() || !dir.isDirectory()) {
return;
}
// 如果存在 就获取包下的所有文件 包括目录
File[] dirfiles = dir.listFiles(new FileFilter() {
// 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
public boolean accept(File file) {
return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
}
});
// 循环所有文件
for (File file : dirfiles) {
// 如果是目录 则递归继续扫描
if (file.isDirectory()) {
findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive,
classes);
} else {
// 如果是java类文件 去掉后面的.class 只留下类名
String className = file.getName().substring(0, file.getName().length() - 6);
String pakClazzName = packageName + '.' + className;
//过滤掉不需要的类
if (filterClazzList.contains(pakClazzName)) {
continue;
}
try {
// 添加到集合中去
classes.add(Class.forName(pakClazzName));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
}
}
参考: