A PowerMock Replacement for JDK 17 and JUnit 5

Introducing a home-made PowerMock replacement

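After upgrading to JDK 17 and JUnit 5, PowerMock's APIs for static methods and private methods could no longer be used, which made migrating existing test code very difficult.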

The two biggest losses were mocking public static methods and invoking private methods, so I wrote my own code to replace PowerMock's main features.

How the original PowerMock API was used:

Mocking static classes:
	PowerMockito.mockStatic(RedisLock.class);
	PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
Mocking a static method:
    Mockito.when(RedisClient.getJedis()).thenReturn(jedis);
Setting a private static final field:
	Whitebox.setInternalState(RedisLock.class, "LOG", LOG);
Invoking a private method:
    Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);        
Verifying a static method call:
	PowerMockito.verifyStatic(RedisClient.class, atLeast(1));
	RedisClient.returnRedis(jedis);

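To change users' existing habits as little as possible, the new classes are named after PowerMock's original PowerMockito and Whitebox classes, and their interfaces stay as close as possible to the original definitions. Here is the end result first: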

Mocking static classes (unchanged):
	PowerMockito.mockStatic(RedisLock.class);
	PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
Mocking a static method (unchanged):
    Mockito.when(RedisClient.getJedis()).thenReturn(jedis);
Setting a private static final field (unchanged):
	Whitebox.setInternalState(RedisLock.class, "LOG", LOG);
Invoking a private method (unchanged):
    Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);        
Verifying a static method call (a small, deliberate change). The old form:
	PowerMockito.verifyStatic(RedisClient.class, atLeast(1));
	RedisClient.returnRedis(jedis);
The new form, with the call to verify wrapped in a lambda:
PowerMockito.verifyStatic(RedisClient.class, atLeast(1), () -> RedisClient.returnRedis(jedis));
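Why must the third argument be a lambda such as `() -> RedisClient.returnRedis(jedis)`? Passing the call directly would execute it immediately (and a `void` call cannot be an argument at all), whereas a lambda lets the verifier switch into "verification mode" first and only then replay the call. The stdlib-only sketch below models that idea with hypothetical names (`Verification`, `returnRedis`, `verifyAtLeast`); it is a conceptual model under those assumptions, not Mockito's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

public class DeferredVerificationDemo {
    // Hypothetical, simplified stand-in for Mockito's MockedStatic.Verification.
    @FunctionalInterface
    interface Verification {
        void apply();
    }

    // Hypothetical static-mock state: calls recorded so far, and the current mode.
    static final List<String> RECORDED = new ArrayList<>();
    static boolean verifying = false;
    static String lastVerifiedCall = null;

    // Hypothetical mocked static method: recorded normally, captured (not recorded) while verifying.
    static void returnRedis(String jedis) {
        String call = "returnRedis(" + jedis + ")";
        if (verifying) {
            lastVerifiedCall = call;
        } else {
            RECORDED.add(call);
        }
    }

    // Switches into verification mode before running the lambda, then counts matching calls.
    static void verifyAtLeast(int minTimes, Verification verification) {
        verifying = true;
        try {
            verification.apply();
        } finally {
            verifying = false;
        }
        long count = RECORDED.stream().filter(call -> call.equals(lastVerifiedCall)).count();
        if (count < minTimes) {
            throw new AssertionError("Wanted at least " + minTimes + " x " + lastVerifiedCall + " but got " + count);
        }
    }

    public static void main(String[] args) {
        returnRedis("jedis");                         // the code under test calls the static method
        verifyAtLeast(1, () -> returnRedis("jedis")); // passes: exactly one matching call was recorded
        System.out.println("recorded calls: " + RECORDED.size()); // prints "recorded calls: 1"
    }
}
```

Because the verifier controls when the lambda runs, the replayed call is captured instead of being counted as a new invocation.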

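As you can see, very little of the existing code has to change. At this point you may object: doesn't Mockito, as used with JUnit 5, already come with static mocking built in? Why bother?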

Let's compare with Mockito's built-in static mocking, using the same cases as above.

Mocking static classes (the PowerMockito form, followed by the native Mockito equivalent):
	PowerMockito.mockStatic(RedisLock.class);
	PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
    try (MockedStatic<RedisLock> mocked0 = mockStatic(RedisLock.class);){
        ...
    }
    try (MockedStatic<RedisLock> mocked0 = mockStatic(RedisLock.class);
        MockedStatic<RedisClient> mocked1 = mockStatic(RedisClient.class);
        ){
        ...
    }
Mocking a static method:
    try (MockedStatic<RedisClient> mocked0 = mockStatic(RedisClient.class);){
        mocked0.when(()->RedisClient.getJedis()).thenReturn(jedis);
    }
Setting a private static final field:
	Whitebox.setInternalState(RedisLock.class,RedisLock.class, "LOG", LOG);
Invoking a private method:
    Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);        
Verifying a static method call:
     try (MockedStatic<RedisClient> mocked0 = mockStatic(RedisClient.class);){
        mocked0.verify(()->RedisClient.getJedis(), atLeast(1));
    }
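With the native API each `MockedStatic` is an `AutoCloseable` that is only active inside its `try` block, and try-with-resources closes resources in reverse declaration order, so all stubbing and verification must happen inside the block. The stdlib-only sketch below uses a hypothetical `FakeStaticMock` to show just that lifecycle; it does not mock anything.

```java
import java.util.ArrayList;
import java.util.List;

public class TryWithResourcesScopeDemo {
    static final List<String> EVENTS = new ArrayList<>();

    // Hypothetical stand-in for MockedStatic<T>: "active" between construction and close().
    static final class FakeStaticMock implements AutoCloseable {
        private final String name;

        FakeStaticMock(String name) {
            this.name = name;
            EVENTS.add("open " + name);
        }

        @Override
        public void close() {
            EVENTS.add("close " + name);
        }
    }

    static void runTest() {
        try (FakeStaticMock mocked0 = new FakeStaticMock("RedisLock");
             FakeStaticMock mocked1 = new FakeStaticMock("RedisClient")) {
            EVENTS.add("test body"); // stubbing and assertions must happen in here
        }
    }

    public static void main(String[] args) {
        runTest();
        // prints [open RedisLock, open RedisClient, test body, close RedisClient, close RedisLock]
        System.out.println(EVENTS);
    }
}
```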

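The difference is already considerable. If that does not seem big enough, look at a real example: below is an excerpt of test code from a service in one of our microservice domains, written with native Mockito: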

@ExtendWith(MockitoExtension.class)
public class DeliveryServiceTaskTest {
    MockedStatic<EnvVariable> aStatic1;
    MockedStatic<ConfigUtils> aStatic2;
    MockedStatic<LockUtil> aStatic3;
    MockedStatic<DRMClient> aStatic4;
    
    @BeforeEach
    public void init() throws Exception {
        DistributedResourceManager manager = Mockito.mock(DistributedResourceManager.class);
        aStatic4.when(()->DRMClient.getInstance()).thenReturn(manager);
        Whitebox.setInternalState(SpringContextHelper.class,SpringContextHelper.class,"context",context);
        mockSystemProperty();
    }
    
    @AfterEach
    public void after() {
        aStatic1.close();
        aStatic2.close();
        aStatic3.close();
    }
    
    @Test
    public void should_not_sendEvent_as_trylock_false() {
        DeliveryServiceTask deliveryTaskSpy = Mockito.spy(deliveryServiceTask);
        try (MockedStatic<LockUtil> mocked0 = mockStatic(LockUtil.class);
            MockedStatic<ConfigUtils> mocked1 = mockStatic(ConfigUtils.class)) {
            mocked0.when(() -> LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(true);
            Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(false);

            deliveryTaskSpy.start();

            Mockito.verify(deliveryTaskSpy, never()).sendEvent(Mockito.any(), Mockito.any());
        }
    }

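It is easy to see that this code is bloated and hard to read. Every mocked static class needs a variable name that someone had to think up, and every mock has to be released once the test is done. When a test mocks many static classes, you either open the mocks up front and close them afterwards, which makes it very easy to forget one (the example above never closes aStatic4), or you end up with a huge, ugly try-with-resources header. Readability is the first requirement of a test, yet this code is hard to even describe.

Anyone who knows our microservices' code knows that a single test class may well mock a dozen or more static classes; imagine how bad the code becomes then. Of course, the other side deserves criticism too: a class that needs a dozen static classes mocked is itself badly written.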
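The remedy boils down to a per-thread registry: each static mock is opened at most once, stored by class, and everything is closed in one place after the test. Below is a stdlib-only sketch of that pattern; `Resource` is a hypothetical stand-in for `MockedStatic`, and `open`, `get` and `closeAll` are illustrative names, not Mockito API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class StaticMockRegistryDemo {
    // Hypothetical stand-in for MockedStatic<?>: remembers whether it was closed.
    static final class Resource implements AutoCloseable {
        final Class<?> type;
        boolean closed;

        Resource(Class<?> type) {
            this.type = type;
        }

        @Override
        public void close() {
            closed = true;
        }
    }

    // One registry per thread, created lazily for whichever thread asks for it.
    private static final ThreadLocal<Map<Class<?>, Resource>> REGISTRY =
        ThreadLocal.withInitial(LinkedHashMap::new);

    // Opens one resource per class; repeated calls for the same class reuse the existing entry.
    static void open(Class<?>... types) {
        for (Class<?> type : types) {
            REGISTRY.get().computeIfAbsent(type, Resource::new);
        }
    }

    static Resource get(Class<?> type) {
        return REGISTRY.get().get(type);
    }

    // Meant to run once after each test (e.g. from an after-each hook): closes everything, then clears.
    static List<Resource> closeAll() {
        List<Resource> closed = new ArrayList<>(REGISTRY.get().values());
        for (Resource resource : closed) {
            resource.close();
        }
        REGISTRY.get().clear();
        return closed;
    }

    public static void main(String[] args) {
        open(String.class, Integer.class, Long.class);
        open(String.class); // no second resource is created for String
        List<Resource> closed = closeAll();
        System.out.println(closed.size() + " closed"); // prints "3 closed"
    }
}
```

Keying by class makes repeated registration harmless, and a single cleanup point means no individual mock can be forgotten.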

Here is the same example after conversion:

@ExtendWith({MockitoExtension.class, MockStaticExtension.class})
public class DeliveryServiceTaskTest {
    @BeforeEach
    public void init() throws Exception {
        PowerMockito.mockStatic(EnvVariable.class, ConfigUtils.class, LockUtil.class, DRMClient.class);
        DistributedResourceManager manager = Mockito.mock(DistributedResourceManager.class);
        Mockito.when(DRMClient.getInstance()).thenReturn(manager);
        Whitebox.setInternalState(SpringContextHelper.class,"context",context);
    }  
    
    @Test
    public void should_not_sendEvent_as_trylock_false() {
        DeliveryServiceTask deliveryTaskSpy = Mockito.spy(deliveryServiceTask);
        Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(true);
        Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(false);
        
        deliveryTaskSpy.start();
        
        Mockito.verify(deliveryTaskSpy, never()).sendEvent(Mockito.any(), Mockito.any());
    }
}

Usage is extremely simple: annotate the test class with @ExtendWith({MockitoExtension.class, MockStaticExtension.class}), then write the test as usual.

Another example from one of our microservices, mocking many static classes at once:

    @Test
    public void test_gen_notify_req_should_return_not_null_when_test_data_combination() throws Exception {
        try (MockedStatic<SystemLangCache> aStatic = Mockito.mockStatic(SystemLangCache.class);
             MockedStatic<SpringContextUtil> aStatic1 = Mockito.mockStatic(SpringContextUtil.class);
             MockedStatic<ApplicationContext> aStatic2 = Mockito.mockStatic(ApplicationContext.class);
             MockedStatic<CommonLogicUtils> aStatic3 = Mockito.mockStatic(CommonLogicUtils.class);
             MockedStatic<DateConvertor> aStatic4 = Mockito.mockStatic(DateConvertor.class);
             MockedStatic<JacksonUtil> aStatic5 = Mockito.mockStatic(JacksonUtil.class);
             MockedStatic<SubTemplateInfoFactory> aStatic6 = Mockito.mockStatic(SubTemplateInfoFactory.class);
             MockedStatic<TemplateUtils> aStatic7 = Mockito.mockStatic(TemplateUtils.class)) {
            
            // setup
            Mockito.when(SystemLangCache.getDefaultLang()).thenReturn("string");
            Mockito.when(SpringContextUtil.getBean("drmReader")).thenReturn(drmReader);
            ...
        }
    }

After conversion:

    @Test
    public void test_gen_notify_req_should_return_not_null_when_test_data_combination() throws Exception {
        PowerMockito.mockStatic(SystemLangCache.class, SpringContextUtil.class, ApplicationContext.class,
            CommonLogicUtils.class, DateConvertor.class, JacksonUtil.class, SubTemplateInfoFactory.class,
            TemplateUtils.class);

        // setup
        Mockito.when(SystemLangCache.getDefaultLang()).thenReturn("string");
        Mockito.when(SpringContextUtil.getBean("drmReader")).thenReturn(drmReader);
        ...
    }

There is no need to manage opening and closing the mocks, no try-with-resources cluttering the code structure, and no need to name every mock.

Source code

The source code follows; criticism and corrections are welcome.

1 PowerMockito

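It is built around mockStatic and verifyStatic. Static mocks are kept per thread in a map keyed by class, so mocking the same class twice is a no-op, and closeAllStatic() releases all of them at once.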

import static org.mockito.Mockito.times;

import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;

import java.util.HashMap;
import java.util.Map;

public class PowerMockito {
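    // withInitial creates the map lazily for every thread; a static block would only
    // populate it for the thread that loaded the class, leaving other threads with null.
    private static final ThreadLocal<Map<Class<?>, MockedStatic<?>>> MOCKED_STATICS =
        ThreadLocal.withInitial(HashMap::new);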

    public static void mockStatic(Class<?> type, Class<?>... aClasses) {
        doMockStatic(type);
        if (aClasses != null) {
            for (Class<?> aClass : aClasses) {
                doMockStatic(aClass);
            }
        }
    }

    private static void doMockStatic(Class<?> type) {
        assertNotNull(type, "mock class can not be null");
        Map<Class<?>, MockedStatic<?>> classMockedStaticMap = MOCKED_STATICS.get();
        if (!classMockedStaticMap.containsKey(type)) {
            classMockedStaticMap.put(type, Mockito.mockStatic(type));
        }
    }

    public static void closeAllStatic() {
        for (Map.Entry<Class<?>, MockedStatic<?>> classMockedStaticEntry : MOCKED_STATICS.get().entrySet()) {
            classMockedStaticEntry.getValue().close();
        }
        MOCKED_STATICS.get().clear();
    }

    public static MockedStatic<?> get(Class<?> aClass) {
        if (!MOCKED_STATICS.get().containsKey(aClass)) {
            throw new IllegalStateException("Static mock for " + aClass.getName() + " is not found");
        }
        return MOCKED_STATICS.get().get(aClass);
    }

    public static boolean contains(Class<?> aClass) {
        return MOCKED_STATICS.get().containsKey(aClass);
    }

    public static void verifyStatic(Class<?> aClass, MockedStatic.Verification verification) {
        verifyStatic(aClass, times(1), verification);
    }

    public static void verifyStatic(Class<?> aClass, VerificationMode verificationMode,
        MockedStatic.Verification verification) {
        assertNotNull(aClass, "Static mock class can not be null");
        assertNotNull(verificationMode, "Static mock verificationMode can not be null");
        assertNotNull(verification, "Static mock verification can not be null");
        if (!MOCKED_STATICS.get().containsKey(aClass)) {
            throw new IllegalStateException("Static mock for " + aClass.getName() + " is not found");
        }
        MOCKED_STATICS.get().get(aClass).verify(verification, verificationMode);
    }

    public static void closeStatic(Class<?> type, Class<?>... aClasses) {
        closeStatic(type);
        if (aClasses != null) {
            for (Class<?> aClass : aClasses) {
                closeStatic(aClass);
            }
        }
    }

    private static void closeStatic(Class<?> type) {
        Map<Class<?>, MockedStatic<?>> classMockedStaticMap = MOCKED_STATICS.get();
        if (classMockedStaticMap.containsKey(type)) {
            classMockedStaticMap.get(type).close();
            classMockedStaticMap.remove(type);
        }
    }

    private static void assertNotNull(Object object, String msg) {
        if (object == null) {
            throw new IllegalArgumentException(msg);
        }
    }
}
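One detail worth checking in a thread-local registry like this is how the `ThreadLocal` is initialized. A value set in a `static` block exists only for the thread that happened to initialize the class; any other thread (for example a parallel test worker) sees `null`, while `ThreadLocal.withInitial` gives every thread its own value. The stdlib-only sketch below shows the difference; `ThreadLocalInitDemo` and its fields are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

public class ThreadLocalInitDemo {
    // Hypothetical registry set up in a static block: only the class-initializing thread gets a map.
    static final ThreadLocal<Map<String, String>> STATIC_BLOCK = new ThreadLocal<>();
    // Hypothetical registry using withInitial: every thread gets its own map on first get().
    static final ThreadLocal<Map<String, String>> WITH_INITIAL = ThreadLocal.withInitial(HashMap::new);

    static {
        STATIC_BLOCK.set(new HashMap<>());
    }

    // Reads both thread-locals from a freshly started thread and returns what that thread saw.
    static String[] readFromOtherThread() {
        AtomicReference<String[]> seen = new AtomicReference<>();
        Thread worker = new Thread(() -> seen.set(new String[] {
            String.valueOf(STATIC_BLOCK.get()), // "null": the static block never ran on this thread
            String.valueOf(WITH_INITIAL.get())  // "{}": an empty map created on demand
        }));
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen.get();
    }

    public static void main(String[] args) {
        String[] seen = readFromOtherThread();
        System.out.println(seen[0] + " / " + seen[1]); // prints "null / {}"
    }
}
```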

2 Whitebox: PowerMock's Whitebox features

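It provides the original framework's setInternalState, getInternalState, getField and invokeMethod with the same API. Setting static final fields took the most effort: Field.set never writes a static final field, even after setAccessible(true), and on recent JDKs (including 17) the classic workaround of fetching Field's own "modifiers" field with getDeclaredField no longer works because that field is hidden from reflection. The code below therefore reaches it through the internal Class.getDeclaredFields0 method. On JDK 17 this deep reflection into java.base normally only succeeds if the test JVM opens the packages involved, for example with --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/java.lang.reflect=ALL-UNNAMED. Also note that compile-time constants (for example a static final String initialized with a literal) are inlined by the compiler, so overwriting them has no visible effect.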
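The underlying restriction is easy to reproduce with plain reflection. Per the `Field.set` contract, `setAccessible(true)` grants write access to a `final` field only if the field is non-static (and its class is not a hidden class or a record), so writing a `static final` field throws `IllegalAccessException`. A minimal stdlib-only sketch, with a hypothetical `Holder` class:

```java
import java.lang.reflect.Field;

public class StaticFinalDemo {
    // Hypothetical class under test; neither field is a compile-time constant.
    static class Holder {
        private static final Object LOG = new Object();
        private final String name = new String("original");
    }

    // Tries to overwrite a final field of Holder reflectively; returns true if the write succeeded.
    static boolean trySet(String fieldName, Object target, Object value) {
        try {
            Field field = Holder.class.getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(target, value);
            return true;
        } catch (IllegalAccessException e) {
            return false; // static final: still rejected after setAccessible(true)
        } catch (NoSuchFieldException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(trySet("LOG", null, new Object()));        // prints false
        System.out.println(trySet("name", new Holder(), "replaced")); // prints true
    }
}
```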

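// Assumes SLF4J, matching the Logger/LoggerFactory usage below
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Map;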

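/**
 * Reflection helpers that mirror PowerMock's Whitebox API (setInternalState, getInternalState, getField, invokeMethod).
 *
 * @since 2024-09-12
 */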
public class Whitebox {
    public static final Logger LOG = LoggerFactory.getLogger(Whitebox.class);

    private static final Map<Class<?>, Class<?>> PRIMITIVE_TYPE_WRAPPER =
        Map.of(int.class, Integer.class, long.class, Long.class, boolean.class, Boolean.class, byte.class, Byte.class,
            short.class, Short.class, char.class, Character.class, double.class, Double.class, float.class, Float.class);

    public static void setInternalState(Object object, String fieldName, Object value) {
        if (object == null) {
            throw new IllegalArgumentException("The instance containing the field cannot be null");
        }
        Class<?> aClass = object instanceof Class<?> ? (Class<?>) object : object.getClass();
        try {
            Field field = getField(aClass, fieldName);
            removeFinalModifierIfPresent(field);
            field.set(object, value);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException("Internal error: Failed to set field in method setInternalState.", e);
        }
    }

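    public static Field getField(Class<?> declaringClass, String fieldName) throws NoSuchFieldException {
        Field field = getDeclaredField(declaringClass, fieldName);
        if (field == null) {
            // Report a missing field explicitly instead of failing later with a NullPointerException
            throw new NoSuchFieldException(fieldName + " not found in " + declaringClass.getName() + " or its superclasses");
        }
        field.setAccessible(true);
        return field;
    }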

    public static <T> T getInternalState(Object object, String fieldName) {
        try {
            Field foundField = getField(getClassType(object), fieldName);
            return (T) foundField.get(object);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            throw new RuntimeException("Internal error: Failed to get field in method getInternalState.", e);
        }
    }

    public static <T> T invokeMethod(Object tested, String methodToExecute) throws Exception {
        Class<?> testedType = getClassType(tested);
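        // Explicit array casts: a bare null passed to a varargs parameter triggers a compiler warning
        Method declaredMethod = getDeclaredMethod(testedType, methodToExecute, (Class<?>[]) null);
        return doInvokeMethod(tested, declaredMethod, methodToExecute, (Object[]) null);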
    }

    public static <T> T invokeMethod(Object tested, String methodToExecute, Object... arguments) throws Exception {
        Class<?> testedType = getClassType(tested);
        if (arguments == null) {
            arguments = new Object[1];
            arguments[0] = null;
        }
        Class<?>[] parameterTypes = getParameterClasses(arguments);
        Method declaredMethod = getDeclaredMethod(testedType, methodToExecute, parameterTypes);
        return doInvokeMethod(tested, declaredMethod, methodToExecute, arguments);
    }

    private static <T> T doInvokeMethod(Object tested, Method declaredMethod, String methodToExecute,
        Object... arguments) throws Exception {
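        if (declaredMethod == null) {
            // Fail loudly: silently returning null would let a misspelled method name pass the test
            throw new NoSuchMethodException("Cannot find method " + methodToExecute + " in "
                + getClassType(tested).getName() + " or its superclasses");
        }

        declaredMethod.setAccessible(true);
        try {
            return (T) declaredMethod.invoke(tested, arguments);
        } catch (InvocationTargetException e) {
            // Rethrow the real cause; Errors such as AssertionError are not Exceptions and cannot be cast to one
            Throwable target = e.getTargetException();
            if (target instanceof Exception) {
                throw (Exception) target;
            }
            if (target instanceof Error) {
                throw (Error) target;
            }
            throw e;
        }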
    }

    private static Class<?>[] getParameterClasses(Object[] arguments) {
        if (arguments == null) {
            return null;
        }
        Class<?>[] parameterTypes = new Class[arguments.length];
        for (int i = 0; i < arguments.length; i++) {
            if (arguments[i] == null) {
                parameterTypes[i] = null;
                continue;
            }
            // Use the runtime class even when the argument is itself a Class object
            parameterTypes[i] = arguments[i].getClass();
        }
        return parameterTypes;
    }

    private static void removeFinalModifierIfPresent(Field field) throws IllegalAccessException {
        int fieldModifiersMask = field.getModifiers();
        if ((fieldModifiersMask & Modifier.FINAL) == Modifier.FINAL) {
            setModifiersToField(field);
        }
    }

    private static void setModifiersToField(Field field) {
        try {
            Method getDeclaredFields0 = Class.class.getDeclaredMethod("getDeclaredFields0", boolean.class);
            getDeclaredFields0.setAccessible(true);
            Field[] fields = (Field[]) getDeclaredFields0.invoke(Field.class, false);
            Field modifiersField = null;
            for (Field each : fields) {
                if ("modifiers".equals(each.getName())) {
                    modifiersField = each;
                    break;
                }
            }
            modifiersField.setAccessible(true);
            modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL);
        } catch (InvocationTargetException | NoSuchMethodException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    private static Class<?> getClassType(Object object) {
        if (object == null) {
            throw new IllegalArgumentException("The object to perform the operation on cannot be null.");
        }
        return object instanceof Class<?> ? (Class<?>) object : object.getClass();
    }

    public static Method getDeclaredMethod(Class<?> aClass, String methodName, Class<?>... parameterTypes) {
        // Walk up the hierarchy; getSuperclass() returns null for interfaces, so guard against it
        for (Class<?> clazz = aClass; clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
            try {
                for (Method declaredMethod : clazz.getDeclaredMethods()) {
                    if (declaredMethod.getName().equals(methodName)
                        && isParameterTypesRight(declaredMethod, parameterTypes)) {
                        return declaredMethod;
                    }
                }
            } catch (Exception e) {
                // Ignore and continue with the superclass
            }
        }
        return null;
    }

    public static Field getDeclaredField(Class<?> aClass, String fieldName) {
        for (Class<?> clazz = aClass; clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
            try {
                return clazz.getDeclaredField(fieldName);
            } catch (Exception e) {
                // Not declared here: continue with the superclass
            }
        }
        return null;
    }

    private static boolean isParameterTypesRight(Method declaredMethod, Class<?>[] parameterTypes) {
        Class<?>[] types = declaredMethod.getParameterTypes();
        if (parameterTypes == null) {
            return types == null || types.length == 0;
        }

        if (types == null) {
            return parameterTypes == null || parameterTypes.length == 0;
        }

        if (parameterTypes.length != types.length) {
            return false;
        }

        for (int i = 0; i < types.length; i++) {
            if (parameterTypes[i] == null) {
                continue;
            }
            if (parameterTypes[i] != types[i] && !types[i].isAssignableFrom(parameterTypes[i])
                && !isWrapperOfPrimitive(types[i], parameterTypes[i])) {
                return false;
            }
        }

        return true;
    }

    private static boolean isWrapperOfPrimitive(Class<?> primitiveClass, Class<?> wrapperClass) {
        if (!PRIMITIVE_TYPE_WRAPPER.containsKey(primitiveClass)) {
            return false;
        }
        return PRIMITIVE_TYPE_WRAPPER.get(primitiveClass).equals(wrapperClass);
    }
}
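Why does invokeMethod scan getDeclaredMethods() (with the isParameterTypesRight/isWrapperOfPrimitive pair above) instead of calling `Class.getDeclaredMethod(name, types)` directly? Arguments arrive boxed through `Object...`, so a `long` parameter shows up as `Long`, and `getDeclaredMethod` only matches exact parameter types. The stdlib-only sketch below shows the mismatch and a minimal wrapper-aware lookup; `Target`, `renew` and `findMethod` are hypothetical names chosen for illustration.

```java
import java.lang.reflect.Method;
import java.util.Map;

public class BoxedArgumentLookupDemo {
    // Hypothetical class with a private static method taking a primitive long.
    static class Target {
        private static String renew(String lockNo, long expireMillis) {
            return lockNo + ":" + expireMillis;
        }
    }

    private static final Map<Class<?>, Class<?>> WRAPPERS = Map.of(
        int.class, Integer.class, long.class, Long.class, boolean.class, Boolean.class,
        byte.class, Byte.class, short.class, Short.class, char.class, Character.class,
        float.class, Float.class, double.class, Double.class);

    // Matches by name and runtime argument types, accepting a wrapper for a primitive parameter.
    static Method findMethod(Class<?> type, String name, Object... args) {
        for (Method method : type.getDeclaredMethods()) {
            Class<?>[] params = method.getParameterTypes();
            if (!method.getName().equals(name) || params.length != args.length) {
                continue;
            }
            boolean matches = true;
            for (int i = 0; i < params.length && matches; i++) {
                Class<?> argType = args[i] == null ? null : args[i].getClass();
                // null fits any reference parameter; otherwise allow subtypes or the primitive's wrapper
                matches = argType == null ? !params[i].isPrimitive()
                    : params[i].isAssignableFrom(argType) || argType.equals(WRAPPERS.get(params[i]));
            }
            if (matches) {
                return method;
            }
        }
        return null;
    }

    public static void main(String[] args) throws Exception {
        Object[] callArgs = {"lockNo101", 30000L}; // 30000L is boxed to Long here
        try {
            Target.class.getDeclaredMethod("renew", String.class, Long.class);
        } catch (NoSuchMethodException e) {
            System.out.println("exact lookup failed"); // renew takes a primitive long, not Long
        }
        Method method = findMethod(Target.class, "renew", callArgs);
        method.setAccessible(true);
        System.out.println(method.invoke(null, callArgs)); // prints "lockNo101:30000"
    }
}
```

`Method.invoke` unboxes the `Long` back to `long` itself, so only the lookup step needs to know about wrappers.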

3 MockStaticExtension

import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

/**
 * JUnit 5 extension that closes every static mock registered through PowerMockito after each test.
 *
 * @since 2024-10-14
 */
public class MockStaticExtension implements AfterEachCallback {

    @Override
    public void afterEach(ExtensionContext context) throws Exception {
        PowerMockito.closeAllStatic();
    }
}

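One last word: I wrote this utility mostly because I had no choice. When static classes show up in large numbers in a single class, either someone is abusing statics and breaking encapsulation by turning into static utility classes what should have been ordinary objects, or someone has made a class too big and too mixed, a hodgepodge. As the saying goes, "I have a three-foot sword that can slay every demon under heaven; I only wish it never has to leave its sheath." I hope we all keep good coding habits and encourage one another.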
