JAVA spring mvc 项目单元测试基类及单元测试自动生成工具类

JAVA spring mvc 项目单元测试基类及单元测试自动生成工具类

测试类基类

package com.fzy.gdaba.test.web.utils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.json.JSONArray;
import org.json.JSONException;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.http.*;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestMapping;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * ClassName: SupperTest <br>
 * Description: 测试类-基类
 *
 * @author sunlight
 * @date 2020/2/26 17:55
 * @since JDK 1.8
 */
@RunWith(SpringJUnit4ClassRunner.class)
// 使用spring-boot的方式启动单元测试,并且使用默认的端口 模拟http请求
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.DEFINED_PORT)
// 测试类方法执行顺序  按代码顺序执行
@FixMethodOrder(MethodSorters.JVM)
// 启用 @ConfigurationProperties注解
@EnableConfigurationProperties({ServerProperties.class})
public abstract class SupperTest {

    @Autowired
    private ServerProperties serverProperties;
    HttpHeaders header = new HttpHeaders();
    // 日志类
    private final static Logger LOG = LoggerFactory.getLogger(SupperTest.class);
    private static String info = "";

    private final TestRestTemplate TEST_REST_TEMPLATE = new TestRestTemplate();
    protected final HttpMethod POST = HttpMethod.POST;
    protected final HttpMethod GET = HttpMethod.GET;
    protected final HttpMethod PUT = HttpMethod.PUT;
    protected final HttpMethod DELETE = HttpMethod.DELETE;

    // 用户登陆名
    protected String loginCode = "sunlight-yb";
    // 用户登陆密码
    protected String password = "123";
    protected String token;

    /** 执行 */
    protected static final boolean RUN = true;
    /** 忽略 */
    protected static final boolean IGNORE = false;

    /**
     * Description: 返回请求路径[clazz不为空会把Controller请求路径进行拼接]
     *
     * @param clazz Controller
     * @return 返回请求路径[clazz不为空会把Controller请求路径进行拼接]
     * @author sunlight
     */
    protected String getLocalhostUrl(Class<?> clazz) {
        int port = serverProperties.getPort() == null ? 8080 : serverProperties.getPort();
        if (clazz != null && clazz.isAnnotationPresent(RequestMapping.class)) {
            return "http://localhost:" + port + clazz.getAnnotation(RequestMapping.class).value()[0];
        }
        return "http://localhost:".concat(String.valueOf(port));
    }

    /**
     * Description: 使用反射返回响应信息
     *
     * @param remark 备注
     * @param clazz 反射类
     * @param method 请求方式[GET, POST, PUT, DELETE;]
     * @param apiUrl api地址
     * @param request 入参
     * @param urlVariables url参数
     * @param <T> 入参类型泛型
     * @param <R> 出参类型泛型
     * @return R 出参类型
     * @author sunlight
     */
    protected <T, R> R execute(String remark, Class<R> clazz, HttpMethod method, String apiUrl, T request
            , Object... urlVariables) {
        String token = getToken();
        long start = System.currentTimeMillis();
        LOG.info("\n\n//======================== {}:开始测试[请求地址<{}>] ========================//\n\n", remark, apiUrl);
        LOG.info("请求参数:{}\n", Objects.requireNonNull(prettyFormat(toJSONString(request))));
        LOG.info("地址参数:{}\n", Objects.requireNonNull(prettyFormat(toJSONString(urlVariables))));
        info += "//======================== " + remark + ":开始测试[请求地址<"+ apiUrl +">] ========================//\n";
        info += "请求参数:" + Objects.requireNonNull(prettyFormat(toJSONString(request))) + "\n";
        info += "地址参数:" + Objects.requireNonNull(prettyFormat(toJSONString(urlVariables))) + "\n";
        // 创建Http请求头部
        HttpHeaders header = new HttpHeaders();
        // 设置token
        header.add(HttpHeaders.AUTHORIZATION, token);
        // 设置内容类型
        header.setContentType(MediaType.APPLICATION_JSON);
        // 创建Http请求实体
        HttpEntity<T> httpEntity = new HttpEntity<>(request, header);
        // 执行Http请求
        ResponseEntity<R> result = TEST_REST_TEMPLATE.exchange(apiUrl, method, httpEntity, clazz, urlVariables);
        if (result.getStatusCodeValue() != 200) {
            LOG.info(prettyFormat(toJSONString(result.getBody())));
            info += "返回值:" + Objects.requireNonNull(prettyFormat(toJSONString(result.getBody()))) + "\n";
            printInfo();
        }
        LOG.info(prettyFormat(toJSONString(result.getBody())));
        info += "返回值:" + Objects.requireNonNull(prettyFormat(toJSONString(result.getBody()))) + "\n";
        long end = System.currentTimeMillis();
        LOG.info("\n\n//======================== {}:完成测试-耗时{}ms ========================//\n\n", remark, end - start);
        info += "//======================== " + remark + ":完成测试-耗时" + (end - start) + "ms ========================//\n";
        // 返回body信息
        return result.getBody();
    }

    /**
     * Description: 获取Token
     *
     * @return java.lang.String
     * @author sunlight
     */
    private String getToken() {
        if (null != this.token) {
            return token;
        }
        return getToken(this.loginCode, this.password);
    }

    /**
     * Description: 获取token
     *
     * @return java.lang.String
     * @author sunlight
     */
    private String getToken(String mobile, String password) {
        info += "获取Token信息:mobile - " + mobile + " | password - " + password + "\n";
        String appType = "pc";
        String language = "zh-cn";
        // TOKEN请求地址
        String apiUrl = "";
        // 创建Http请求头部
        HttpHeaders header = new HttpHeaders();
        // 设置token
        header.add(HttpHeaders.AUTHORIZATION, null);
        // 设置内容类型
        header.setContentType(MediaType.APPLICATION_JSON);
        // 创建Http请求实体
        HttpEntity<String> httpEntity = new HttpEntity<>(null, header);
        // 执行Http请求
        ResponseEntity<String> result = TEST_REST_TEMPLATE.exchange(apiUrl, HttpMethod.GET, httpEntity, String.class
                , mobile, password, appType, language);
        // 返回body信息
        String data = result.getBody();
        String token = null;
        try {
            token = new JSONArray(data).getJSONObject(0).getString("token");
        } catch (JSONException e) {
            e.printStackTrace();
        }
        return token;
    }

    public static String toJSONString(Object obj) {
        if (StringUtils.isEmpty(obj)) {
            return null;
        } else {
            ObjectMapper mapper = new ObjectMapper();
            try {
                // 对象转字符串
                return mapper.writeValueAsString(obj);
            } catch (JsonProcessingException var2) {
                throw new IllegalArgumentException(" clazz to json fail.");
            }
        }
    }

    /**
     * Description: 格式化输出
     *
     * @param json 需格式化的目标
     * @return 格式化后的目标
     * @author sunlight
     */
    private String prettyFormat(String json) {
        if (StringUtils.isEmpty(json)) {
            return "null";
        }
        ObjectMapper mapper = new ObjectMapper();
        String ret;
        try {
            Object obj = mapper.readValue(json, Object.class);
            // writerWithDefaultPrettyPrinter -- 格式化、writeValueAsString -- 对象转字符串
            ret = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(obj);
        } catch (IOException e) {
            return null;
        }
        return System.getProperty("line.separator").concat(ret);
    }

    /**
     * Description: 获取完整请求地址
     *
     * @param url controller 中方法的请求地址
     * @return 完整请求地址 eg:getLocalhostUrl(Controller.class).concat(url);
     * @author sunlight
     */
    protected abstract String getApiUrl(String url);



    /**
     * Description: 执行测试方法
     *
     * @return void
     * @author sunlight
     */
    @Test
    public abstract void execute();

    /**
     * Description: 数组转list
     *
     * @param sources 源数据
     * @return java.util.List<T> list数据
     * @author sunlight
     */
    protected <T> List<T> arraysAsList(T[] sources) {
        return sources == null || sources.length < 1 ? null : Arrays.asList(sources);
    }

    /**
     * Description: 打印本轮测试信息
     *
     * @return void
     * @author sunlight
     */
    protected void printInfo() {
        throw new IllegalArgumentException("//======================== 本次测试完成,测试结果如下: ========================//\n" + info);
    }
}

测试基类用例

package com.fzy.gdaba.test.web.restapi;

import com.fzy.gdaba.test.web.utils.SupperTest;
import javax.servlet.http.HttpServletRequest;
import java.lang.String;
import java.lang.Long;
import com.fzy.gdaba.web.restapi.BlogContentRestApi;

 /**
* ClassName: Test_BlogContentRestApi <br>
* Description: 文章详情相关接口
*
* @author sunlight
* @date 2021/03/25 11:22
* @since JDK 1.8.0_181
*/
public class Test_BlogContentRestApi extends SupperTest {

	@Override
	protected String getApiUrl(String url) {
		return getLocalhostUrl(BlogContentRestApi.class) + url;
	}

	@Override
	public void execute() {
		// false不执行,true执行
		// 通过Uid获取博客内容
		getBlogByUid(false);
		// 通过Uid获取博客内容
		getBlogByZhuangtiUid(false);
		// 通过Uid获取博客点赞数
		getBlogPraiseCountByUid(false);
		// 根据BlogUid获取相关的博客
		getSameBlogByBlogUid(false);
		// 根据标签Uid获取相关的博客
		getSameBlogByTagUid(false);
		// 通过Uid给博客点赞
		praiseBlogByUid(false);

		// 打印测试信息
		printInfo();
	}

	// 通过Uid获取博客内容
	private void getBlogByUid(boolean execute) {
		if (!execute) {return;}
		String uid = null;
		execute("通过Uid获取博客内容", String.class, GET, getApiUrl("/getBlogByUid"), null);
	}

	// 通过Uid获取博客内容
	private void getBlogByZhuangtiUid(boolean execute) {
		if (!execute) {return;}
		String uid = null;
		execute("通过Uid获取博客内容", String.class, GET, getApiUrl("/getBlogByZhuangtiUid"), null);
	}

	// 通过Uid获取博客点赞数
	private void getBlogPraiseCountByUid(boolean execute) {
		if (!execute) {return;}
		String uid = null;
		execute("通过Uid获取博客点赞数", String.class, GET, getApiUrl("/getBlogPraiseCountByUid"), null);
	}

	// 根据BlogUid获取相关的博客
	private void getSameBlogByBlogUid(boolean execute) {
		if (!execute) {return;}
		HttpServletRequest request = null;
		String blogUid = null;
		execute("根据BlogUid获取相关的博客", String.class, GET, getApiUrl("/getSameBlogByBlogUid"), null);
	}

	// 根据标签Uid获取相关的博客
	private void getSameBlogByTagUid(boolean execute) {
		if (!execute) {return;}
		String tagUid = null;
		Long currentPage = null;
		Long pageSize = null;
		execute("根据标签Uid获取相关的博客", String.class, GET, getApiUrl("/getSameBlogByTagUid"), null);
	}

	// 通过Uid给博客点赞
	private void praiseBlogByUid(boolean execute) {
		if (!execute) {return;}
		String uid = null;
		execute("通过Uid给博客点赞", String.class, GET, getApiUrl("/praiseBlogByUid"), null);
	}

}

单元测试类生成工具类

该工具类需与测试基类一起使用,需要项目引用了Spring MVC以及swagger,生成测试类效果请看:测试基类用例

package com.fzy.gdaba.test.web.utils;

import com.fzy.gdaba.web.restapi.BlogContentRestApi;
import com.fzy.gdaba.web.restapi.ClassifyRestApi;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.assertj.core.util.Sets;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.function.Predicate;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * ClassName: GenerationTestCode <br>
 * Description: 生成测试代码
 *
 * @author sunlight
 * @date 2020/12/3 15:13
 * @since JDK 1.8
 */
public class GenerationTestCode {

    public static void main(String[] args) {
        new GenerationTestCode(
            // 代码存放的包路径
            "com.fzy.gdaba.test.web.restapi"
            // 包前缀 com.x3
            , "com.fzy"
            // 要扫描的controller路径
            , "com.fzy.gdaba.web.restapi"
            // 全局变量 -- 可忽略
            , new String[]{"billsId"}
            // 可指定生成的类 -- 可忽略,此属性设置后 scanPack 失效
            , BlogContentRestApi.class, ClassifyRestApi.class
        ).generate();
    }

    /**
     * Description: 生成代码
     *
     * @return void
     * @author sunlight
     */
    public void generate() {
        // 如果指定了需要生成测试类的 controller 则只对指定的controller生成
        Set<Class<?>> classes = controllers != null && controllers.length > 0 ?
            Sets.newHashSet(Arrays.asList(controllers)) : ClassUtil.getClasses(scanPack);
        // 遍历Controller
        Predicate<Class<?>> effectiveController = controller -> null != controller
            && (controller.isAnnotationPresent(RestController.class) || controller.isAnnotationPresent(Controller.class));
        classes.stream().filter(effectiveController)
            .forEach(controller -> {
                // 类
                StringBuilder classStr = new StringBuilder();
                // 引用类
                Set<String> imports = new HashSet<>();
                // 添加Controller引用
                imports.add(controller.getName());
                // 固定方法
                StringBuilder fixedMehthodStr = new StringBuilder();
                // 方法
                StringBuilder mehthodStr = new StringBuilder();
                // 方法调用代码
                StringBuilder executeStr = new StringBuilder();
                // 控制器描述
                String apiTags = null;
                if (controller.isAnnotationPresent(Api.class)) {
                    if (isNotBlankAndNull(controller.getAnnotation(Api.class).value())) {
                        apiTags = controller.getAnnotation(Api.class).value();
                    } else {
                        apiTags = controller.getAnnotation(Api.class).tags()[0];
                    }
                }
                final String testClassName = "Test_" + controller.getSimpleName();
                // 遍历方法 -- 控制需要生成的方法
                Predicate<Method> isMethod = item -> item.isAnnotationPresent(RequestMapping.class)
                    || item.isAnnotationPresent(PostMapping.class)
                    || item.isAnnotationPresent(GetMapping.class)
                    || item.isAnnotationPresent(PutMapping.class)
                    || item.isAnnotationPresent(DeleteMapping.class);
                Arrays.stream(controller.getDeclaredMethods())
                    .filter(isMethod)
                    .sorted(Comparator.comparing(Method::getName))
                    .forEach(method -> {
                        String url = "";
                        String requestMethod = "";
                        String apiOperation = null;
                        if (method.isAnnotationPresent(ApiOperation.class)) {
                            if (isNotBlankAndNull(method.getAnnotation(ApiOperation.class).value())) {
                                apiOperation = method.getAnnotation(ApiOperation.class).value();
                            } else {
                                apiOperation = method.getAnnotation(ApiOperation.class).notes();
                            }
                        }

                        // 添加出参引用
                        if (method.getReturnType().getName().contains(".")) {
                            if (List.class.equals(method.getReturnType())) {
                                // List添加泛型引用
                                imports.add(ClassUtil.getListActualTypeArgumentsTypeName(method));
                            } else {
                                imports.add(method.getReturnType().getName());
                            }
                        }

                        String responseClass = "void".equals(method.getReturnType().getSimpleName()) ?
                            // 无返回值类型
                            "String.class" :
                            List.class.equals(method.getReturnType()) ?
                                // List 类型
                                ClassUtil.getListActualTypeArgumentsSimpleTypeName(method) + "[].class" :
                                // 对象类型
                                method.getReturnType().getSimpleName() + ".class";
                        StringBuilder requestBodyClass = new StringBuilder("null");
                        // 方法调用代码
                        executeStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB)
                            .append(Character.SINGLE_LINE_COMMENTS)
                            .append(apiOperation).append(Character
                            .LINE_BREAK + Character.TAB + Character.TAB)
                            .append(method.getName()).append("(false);");

                        // 方法
                        mehthodStr.append(Character.LINE_BREAK + Character.TAB).append(Character.SINGLE_LINE_COMMENTS).append(apiOperation)
                            .append(Character.LINE_BREAK + Character.TAB)
                            .append("private void ").append(method.getName()).append("(boolean execute) {");

                        // 判断是否执行
                        mehthodStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB).append("if (!execute) {return;}");
                        // 处理请求参数
                        if (method.getParameters() != null && method.getParameters().length > 0) {
                            for (Parameter parameter : method.getParameters()) {
                                // 添加入参引用
                                imports.add(parameter.getType().getName());

                                // 处理 RequestBody 注解参数 注意:只处理第一个
                                if (parameter.isAnnotationPresent(RequestBody.class) && "null".equals(requestBodyClass.toString())) {
                                    requestBodyClass = new StringBuilder(parameter.getName());
                                }

                                // 集合
                                if (List.class.equals(parameter.getType())) {
                                    // 添加List引用
                                    imports.add(List.class.getName());

                                    // 添加ArrayList引用
                                    imports.add(ArrayList.class.getName());

                                    String typeName = "T";
                                    if (parameter.getParameterizedType() instanceof ParameterizedType) {
                                        ParameterizedType type = (ParameterizedType) parameter.getParameterizedType();
                                        typeName = type.getActualTypeArguments()[0].getTypeName().lastIndexOf(".") == -1 ?
                                            type.getActualTypeArguments()[0].getTypeName() :
                                            type.getActualTypeArguments()[0].getTypeName().substring(type.getActualTypeArguments()[0].getTypeName().lastIndexOf(".") + 1);
                                        // 添加List泛型引用
                                        imports.add(type.getActualTypeArguments()[0].getTypeName());
                                    }
                                    parameter.getParameterizedType();
                                    mehthodStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB)
                                        .append("List<").append(typeName).append("> ")
                                        .append(parameter.getName())
                                        .append(" = new ArrayList<>();");
                                    continue;
                                }

                                // 本项目的入参实体特殊处理
                                if (parameter.getType().getName().startsWith(projectPackPrefix)) {
                                    String parameterName = parameter.getName();
                                    mehthodStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB)
                                        .append(parameter.getType().getSimpleName()).append(" ")
                                        .append(parameter.getName())
                                        .append(" = new ").append(parameter.getType().getSimpleName()).append("();");
                                    mehthodStr.append(ClassUtil.getAllSet(parameter.getType(), parameterName, Character.LINE_BREAK + Character.TAB + Character.TAB, this.globalFinalName));
                                    continue;
                                }

                                // 普通对象类型
                                mehthodStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB)
                                    .append(parameter.getType().getSimpleName()).append(" ")
                                    .append(parameter.getName())
                                    .append(" = null;");
                            }

                            // 处理地址入参[PathVariable]
                            for (Parameter parameter : Objects.requireNonNull(method.getParameters())) {
                                if (parameter.isAnnotationPresent(PathVariable.class)) {
                                    requestBodyClass.append(", ").append(parameter.getName());
                                }
                            }
                        }

                        // 判断请求类型
                        if (method.isAnnotationPresent(PostMapping.class)) {
                            url = method.getAnnotation(PostMapping.class).value()[0];
                            requestMethod = "POST";
                        }
                        if (method.isAnnotationPresent(GetMapping.class)) {
                            url = method.getAnnotation(GetMapping.class).value()[0];
                            requestMethod = "GET";
                        }
                        if (method.isAnnotationPresent(PutMapping.class)) {
                            url = method.getAnnotation(PutMapping.class).value()[0];
                            requestMethod = "PUT";
                        }
                        if (method.isAnnotationPresent(DeleteMapping.class)) {
                            url = method.getAnnotation(DeleteMapping.class).value()[0];
                            requestMethod = "DELETE";
                        }
                        if (method.isAnnotationPresent(RequestMapping.class)) {
                            url = method.getAnnotation(RequestMapping.class).value()[0];
                            requestMethod = "POST";
                        }

                        mehthodStr.append(Character.LINE_BREAK + Character.TAB + Character.TAB).append("execute(\"").append(apiOperation).append("\", ")
                            .append(responseClass).append(", ")
                            .append(requestMethod).append(", ")
                            .append("getApiUrl(\"").append(url).append("\"), ")
                            .append(requestBodyClass)
                            .append(");");
                        mehthodStr.append(Character.LINE_BREAK + Character.TAB).append("}\n");
                    });

                // 注释
                classStr.append("package ").append(packageUrl).append(";")
                    .append(Character.LINE_BREAK).append(Character.LINE_BREAK).append("import ").append(SupperTest.class.getName()).append(";")
                    .append(getImports(imports))
                    .append(Character.LINE_BREAK).append(Character.LINE_BREAK).append(" /**")
                    .append(Character.LINE_BREAK).append("* ClassName: ").append(testClassName).append(" <br>")
                    .append(Character.LINE_BREAK).append("* Description: ").append(apiTags)
                    .append(Character.LINE_BREAK).append("*")
                    // 作者
                    .append(Character.LINE_BREAK).append("* @author ").append(System.getenv("USERNAME"))
                    // 日期
                    .append(Character.LINE_BREAK).append("* @date ").append(new SimpleDateFormat("yyyy/MM/dd HH:mm").format(new Date()))
                    // Java 版本
                    .append(Character.LINE_BREAK).append("* @since ").append("JDK ").append(System.getProperty("java.version"))
                    .append(Character.LINE_BREAK).append("*/")
                    .append(Character.LINE_BREAK).append("public class ").append(testClassName).append(" extends ").append(SupperTest.class.getSimpleName()).append(" {");

                StringBuilder globalFinal = new StringBuilder();
                // 常量
                if (this.globalFinalName != null && this.globalFinalName.length > 0) {
                    Arrays.stream(this.globalFinalName).filter(Objects::nonNull).distinct()
                        .forEach(item -> globalFinal.append(Character.LINE_BREAK + Character.LINE_BREAK + Character.TAB)
                            .append("private final String ").append(item).append(" = null;"));
                }

                // 固定方法
                fixedMehthodStr.append(globalFinal)
                    .append(Character.LINE_BREAK + Character.LINE_BREAK + Character.TAB).append("@Override")
                    .append(Character.LINE_BREAK + Character.TAB).append("protected String getApiUrl(String url) {")
                    .append(Character.LINE_BREAK + Character.TAB + Character.TAB).append("return getLocalhostUrl(").append(controller.getSimpleName()).append(".class) + url;")
                    .append(Character.LINE_BREAK + Character.TAB).append("}")
                    .append(Character.LINE_BREAK + Character.LINE_BREAK + Character.TAB).append("@Override")
                    .append(Character.LINE_BREAK + Character.TAB).append("public void execute() {")
                    .append(executeStr)
                    .append(Character.LINE_BREAK + Character.LINE_BREAK + Character.TAB + Character.TAB).append("// 打印测试信息")
                    .append(Character.LINE_BREAK + Character.TAB + Character.TAB).append("printInfo();")
                    .append(Character.LINE_BREAK + Character.TAB).append("}")
                    .append(Character.LINE_BREAK);

                // 拼接全部内容
                classStr.append(fixedMehthodStr.toString()).append(mehthodStr.toString()).append(Character.LINE_BREAK).append("}");
                write(address, testClassName + ".java", classStr.toString(), true);
                System.out.println("测试类生成完毕,存放路径:" + address + "/" + testClassName + ".java");
            });

    }

    private StringBuilder getImports(Set<String> imports) {
        StringBuilder builder = new StringBuilder();
        if (imports != null && imports.size() > 0) {
            imports.stream().sorted(Comparator.reverseOrder()).forEach(item -> builder.append(Character.LINE_BREAK).append("import ").append(item).append(";"));
        }
        return builder;
    }

    /**
     * 代码存储包
     */
    private final String codeStoragePack;

    /**
     * 项目包前缀
     */
    private final String projectPackPrefix;

    /**
     * 扫描包
     */
    private final String scanPack;

    /**
     * 文件分隔符
     */
    private final String fileSeparator = System.getProperty("file.separator");

    /**
     * 当前类地址
     */
    private String address = this.getClass().getResource(fileSeparator).getPath();

    /**
     * 代码存放路径
     */
    private String packageUrl;

    /**
     * 需要生成测试代码的controller
     */
    private Class<?>[] controllers;

    /**
     * 全局常量
     */
    private String[] globalFinalName = {};

    public GenerationTestCode(String codeStoragePack, String projectPackPrefix, String scanPack, Class<?>... controllers) {
        this.codeStoragePack = codeStoragePack;
        this.projectPackPrefix = projectPackPrefix;
        this.scanPack = scanPack;
        this.controllers = controllers;
        init();
    }

    public GenerationTestCode(String codeStoragePack, String projectPackPrefix, String scanPack, String[] globalFinalName, Class<?>... controllers) {
        this.codeStoragePack = codeStoragePack;
        this.projectPackPrefix = projectPackPrefix;
        this.scanPack = scanPack;
        this.controllers = controllers;
        this.globalFinalName = globalFinalName == null ? new String[]{} : globalFinalName;
        init();
    }

    public GenerationTestCode(String projectPackPrefix, String scanPack, String[] globalFinalName, Class<?>... controllers) {
        this.codeStoragePack = null;
        this.projectPackPrefix = projectPackPrefix;
        this.scanPack = scanPack;
        this.controllers = controllers;
        this.globalFinalName = globalFinalName == null ? new String[]{} : globalFinalName;
        init();
    }

    public GenerationTestCode(String projectPackPrefix, String scanPack, Class<?>... controllers) {
        this.codeStoragePack = null;
        this.projectPackPrefix = projectPackPrefix;
        this.scanPack = scanPack;
        this.controllers = controllers;
        init();
    }

    /**
     * Description: 初始
     *
     * @return void
     * @author sunlight
     */
    private void init() {
        if (this.controllers == null || this.controllers.length < 1) {
            checkNullException(scanPack, "扫描包路径为空且未指定controller");
        }
        checkNullException(address, "当前类地址获取失败");
        // 测试类target目录
        String testPath = "/target/test-classes/";
        if (address.lastIndexOf(testPath) == -1) {
            abnormalOutput("当前类只能放在测试文件夹");
        }
        // 计算生成代码实际包路径
        packageUrl = isNotBlankAndNull(codeStoragePack) ? codeStoragePack
            : this.getClass().getName().substring(0, this.getClass().getName().lastIndexOf(".")) + ".generate";
        // 计算当前类地址物理地址
        address = (address.substring(0, address.lastIndexOf(testPath)) + "/src/test/java/" + packageUrl)
            .replace(".", "/");
    }

    /**
     * Description: 异常输出
     *
     * @param errMsg 异常信息
     * @return void
     * @author sunlight
     */
    private void abnormalOutput(final String errMsg) {
        throw new IllegalArgumentException(this.getClass().getSimpleName() + " - error - " + errMsg);
    }

    /**
     * Description: 判断String为空
     *
     * @param str 源字符串
     * @return 为空返回true
     * @author sunlight
     */
    private boolean isBlankOrNull(final String str) {
        return str == null || "".equals(str.trim());
    }

    /**
     * Description: 检查Null异常
     *
     * @param str 要检查的字符
     * @param msg 错误信息
     * @return void
     * @author sunlight
     */
    private void checkNullException(final String str, final String msg) {
        if (isBlankOrNull(str)) {
            abnormalOutput(msg);
        }
    }

    /**
     * ClassName: character <br>
     * Description: 常用字符
     *
     * @author sunlight
     * @date 2020/12/11 14:43
     * @since JDK 1.8
     */
    private static class Character {
        /**
         * 标准空格[8个\r]
         */
        private static final String TAB = "\t";

        /**
         * 换行
         */
        private static final String LINE_BREAK = "\n";

        /**
         * 单行注释
         */
        private static final String SINGLE_LINE_COMMENTS = "// ";
    }

    /**
     * ClassName: ClassUtil <br>
     * Description: Class工具类
     *
     * @author sunlight
     * @date 2020/9/30 14:38
     * @since JDK 1.8
     */
    private static class ClassUtil {

        /**
         * Description: 获取class
         *
         * @param pack 扫描包路径 eg: com.sunlight.controller
         * @return 扫描到的class
         * @author sunlight
         */
        private static Set<Class<?>> getClasses(String pack) {
            // 第一个class类的集合
            Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
            // 是否循环迭代
            boolean recursive = true;
            // 获取包的名字 并进行替换
            String packageName = pack;
            // 包名换成实体路径
            String packageDirName = packageName.replace('.', '/');
            // 定义一个枚举的集合 并进行循环来处理这个目录下的东西
            Enumeration<URL> dirs;
            try {
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // 循环迭代下去
                while (dirs.hasMoreElements()) {
                    // 获取下一个元素
                    URL url = dirs.nextElement();
                    // 得到协议的名称
                    String protocol = url.getProtocol();
                    // 如果是以文件的形式保存在服务器上
                    if ("file".equals(protocol)) {
                        // 获取包的物理路径
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        // 以文件的方式扫描整个包下的文件 并添加到集合中
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                    } else if ("jar".equals(protocol)) {
                        // 如果是jar包文件
                        JarFile jar;
                        try {
                            // 获取jar
                            jar = ((JarURLConnection) url.openConnection()).getJarFile();
                            // 从此jar包 得到一个枚举类
                            Enumeration<JarEntry> entries = jar.entries();
                            // 同样的进行循环迭代
                            while (entries.hasMoreElements()) {
                                // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
                                JarEntry entry = entries.nextElement();
                                String name = entry.getName();
                                // 如果是以/开头的
                                if (name.charAt(0) == '/') {
                                    // 获取后面的字符串
                                    name = name.substring(1);
                                }
                                // 如果前半部分和定义的包名相同
                                if (name.startsWith(packageDirName)) {
                                    int idx = name.lastIndexOf('/');
                                    // 如果以"/"结尾 是一个包
                                    if (idx != -1) {
                                        // 获取包名 把"/"替换成"."
                                        packageName = name.substring(0, idx).replace('/', '.');
                                    }
                                    // 如果可以迭代下去 并且是一个包
                                    // 如果是一个.class文件 而且不是目录
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        // 去掉后面的".class" 获取真正的类名
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            // 添加到classes
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException e) {
                                            e.printStackTrace();
                                        }
                                    }
                                }
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            return classes;
        }

        /**
         * Description: 在文件包中查找并添加类
         *
         * @param packageName 包名
         * @param packagePath 物理路径
         * @param recursive   是否迭代
         * @param classes     扫描到的class集合
         * @return void
         * @author sunlight
         */
        private static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, Set<Class<?>> classes) {
            // 获取此包的目录 建立一个File
            File dir = new File(packagePath);
            // 如果不存在或者 也不是目录就直接返回
            if (!dir.exists() || !dir.isDirectory()) {
                return;
            }
            // 如果存在 就获取包下的所有文件 包括目录
            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            File[] dirfiles = dir.listFiles(file -> (recursive && file.isDirectory()) || (file.getName().endsWith(".class")));
            if (dirfiles == null || dirfiles.length < 1) {
                return;
            }
            // 循环所有文件
            for (File file : dirfiles) {
                // 如果是目录 则继续扫描
                if (file.isDirectory()) {
                    findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
                } else {
                    // 如果是java类文件 去掉后面的.class 只留下类名
                    String className = file.getName().substring(0, file.getName().length() - 6);
                    try {
                        // 添加到集合中去
                        // 这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
                        classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
                    } catch (ClassNotFoundException e) {
                        // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
                        e.printStackTrace();
                    }
                }
            }
        }

        /**
         * Description: 获取对象所有set方法
         *
         * @param type            class
         * @param parameterName   参数名
         * @param indentation     缩进
         * @param globalFinalName 全局常量
         * @return java.lang.String 所有set方法
         * @author sunlight
         */
        private static String getAllSet(Class<?> type, String parameterName, String indentation, String... globalFinalName) {
            StringBuilder builder = new StringBuilder();
            Arrays.stream(type.getMethods()).sorted(Comparator.comparing(Method::getName)).forEach(method -> {
                String defaults = null;
                // 筛选 set方法
                if (method.getName().startsWith("set")) {
                    // 判断是否有入参,默认取第一入参
                    if (method.getParameters() != null && method.getParameters().length > 0) {
                        // 参数类型名称
                        final String parameterTypename = method.getParameterTypes()[0].getSimpleName();
                        // 判断类型
                        if ("String".equals(parameterTypename)) {
                            defaults = globalFinalName != null && Arrays.stream(globalFinalName).anyMatch(item -> method.getParameters()[0].getName().equals(item))
                                ? "this.".concat(method.getParameters()[0].getName())
                                : "\"" + method.getParameters()[0].getName() + "\"";
                        }
                        if ("byte".equals(parameterTypename) || "short".equals(parameterTypename)
                            || "int".equals(parameterTypename) || "long".equals(parameterTypename)
                            || "double".equals(parameterTypename) || "float".equals(parameterTypename)) {
                            defaults = "0";
                        }
                        if ("boolean".equals(parameterTypename) || "Boolean".equals(parameterTypename)) {
                            defaults = "false";
                        }
                    }
                    builder.append(indentation).append(parameterName).append(".").append(method.getName()).append("(").append(defaults).append(");");
                }
            });
            return builder.toString();
        }

        /**
         * Description: 获取方法出参List实际类型参数类型名称
         *
         * @param method 方法
         * @return 方法出参List实际类型参数类型名称
         * @author sunlight
         */
        private static String getListActualTypeArgumentsTypeName(Method method) {
            if (List.class.equals(method.getReturnType())) {
                return ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments()[0].getTypeName();
            }
            throw new IllegalArgumentException("方法出参不属于List类型");
        }

        /**
         * Description: 获取方法出参List实际类型参数类型名称简称
         *
         * @param method 方法
         * @return 方法出参List实际类型参数类型名称简称
         * @author sunlight
         */
        private static String getListActualTypeArgumentsSimpleTypeName(Method method) {
            if (null == method) {
                throw new IllegalArgumentException("method 为空");
            }
            String typeName = getListActualTypeArgumentsTypeName(method);
            return typeName.lastIndexOf(".") == -1 ? typeName : typeName.substring(typeName.lastIndexOf(".") + 1);
        }
    }

    /**
     * Description: 文件写入
     *
     * @param writePath    写入路径
     * @param writeName    写入文件名称
     * @param writeContent 写入内容
     * @param createFile   创建文件
     * @return 写入成功返回true
     * @author sunlight
     */
    public boolean write(String writePath, String writeName, String writeContent, boolean... createFile) {
        checkNullException(writePath, "文件写入失败,写入路径为空");
        checkNullException(writeName, "文件写入失败,写入文件名称为空");
        checkNullException(writeContent, "文件写入失败,写入内容为空");
        File file = new File(writePath + File.separator + writeName);
        // 创建文件夹
        if (createFile != null && createFile.length > 0 && createFile[0]) {
            File folder = new File(writePath);
            if (!folder.exists() && !folder.mkdirs()) {
                throw new IllegalArgumentException("文件夹:" + folder + "创建失败");
            }
        }
        if (createFile == null || createFile.length == 0 || !createFile[0]) {
            // 判断文件
            if (!file.exists()) {
                throw new IllegalArgumentException("文件写入失败,文件不存在[" + writePath + File.separator + writeName + "]");
            }
        }
        try {
            OutputStream outputStream = new FileOutputStream(file);
            outputStream.write(writeContent.getBytes());
            outputStream.flush();
            outputStream.close();
            return true;
        } catch (IOException e) {
            e.printStackTrace();
            throw new IllegalArgumentException("读取文件失败," + e.getMessage());
        }
    }

    /**
     * Description: 判断String不为空
     *
     * @param str 源字符串
     * @return 不为空返回true
     * @author sunlight
     */
    private static boolean isNotBlankAndNull(final String str) {
        return null != str && !"".equals(str.trim());
    }
}
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值