A Homemade PowerMock Replacement
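After we upgraded to JDK 17 and JUnit 5, the PowerMock APIs we relied on for static and private methods stopped working, and this became a major obstacle to migrating our code.
The two biggest pain points were mocking public static methods and invoking private methods, so we wrote our own code to replace PowerMock's main features.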
Usage of the original PowerMock API:
Mocking static classes:
PowerMockito.mockStatic(RedisLock.class);
PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
Mocking a static method:
Mockito.when(RedisClient.getJedis()).thenReturn(jedis);
Setting a private static final field:
Whitebox.setInternalState(RedisLock.class, "LOG", LOG);
Invoking a private method:
Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);
Verifying a static method call:
PowerMockito.verifyStatic(RedisClient.class, atLeast(1));
RedisClient.returnRedis(jedis);
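To change existing habits as little as possible, the new classes reuse PowerMock's original class names, PowerMockito and Whitebox, and keep their APIs as close to the originals as possible. Here is the end result first:
Mocking static classes (unchanged):
PowerMockito.mockStatic(RedisLock.class);
PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
Mocking a static method (unchanged):
Mockito.when(RedisClient.getJedis()).thenReturn(jedis);
Setting a private static final field (unchanged):
Whitebox.setInternalState(RedisLock.class, "LOG", LOG);
Invoking a private method (unchanged):
Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);
Verifying a static method call (a small, reasonable change). The old form:
PowerMockito.verifyStatic(RedisClient.class, atLeast(1));
RedisClient.returnRedis(jedis);
The new form passes the call to be verified as a lambda:
PowerMockito.verifyStatic(RedisClient.class, atLeast(1), () -> RedisClient.returnRedis(jedis));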
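The third argument of the new verifyStatic is a Mockito MockedStatic.Verification, so the call must be wrapped in a lambda such as () -> RedisClient.returnRedis(jedis) rather than invoked directly: the verifier has to receive the call as code and run it itself while in verification mode. The stdlib-only sketch below illustrates that difference between deferred and eager evaluation. The Verification interface, capture and returnRedis here are hypothetical stand-ins, not Mockito's real types.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class DeferredVerificationDemo {
    // Hypothetical stand-in for Mockito's MockedStatic.Verification:
    // a no-argument callback that may throw.
    @FunctionalInterface
    public interface Verification {
        void apply() throws Throwable;
    }

    // Counts how many times the "static method" body actually ran.
    static final AtomicInteger CALLS = new AtomicInteger();

    // Hypothetical void static method, standing in for RedisClient.returnRedis(jedis).
    static void returnRedis(String jedis) {
        CALLS.incrementAndGet();
    }

    // Hypothetical verifier entry point: it only receives the call as code.
    static Verification capture(Verification verification) {
        return verification;
    }

    public static int callsAfterCaptureOnly() {
        CALLS.set(0);
        capture(() -> returnRedis("jedis")); // deferred: the lambda body has not run
        return CALLS.get();
    }

    public static int callsAfterCaptureAndRun() {
        CALLS.set(0);
        Verification verification = capture(() -> returnRedis("jedis"));
        try {
            verification.apply(); // the verifier decides when to run the captured call
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
        return CALLS.get();
    }

    public static void main(String[] args) {
        System.out.println(callsAfterCaptureOnly());   // 0
        System.out.println(callsAfterCaptureAndRun()); // 1
        // capture(returnRedis("jedis")); // does not compile: a void result is not a Verification
    }
}
```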
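As you can see, very little of the existing test code has to change. At this point you might object: Mockito, which we already use with JUnit 5, supports mocking static methods out of the box, so why bother?
To answer that, let's compare with Mockito's built-in static mocking, using the same cases as above.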
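Mocking static classes. With the homemade PowerMockito:
PowerMockito.mockStatic(RedisLock.class);
PowerMockito.mockStatic(RedisClient.class,RedisLock.class,XXX.class...);
With Mockito's built-in mockStatic, each mock has to be opened in a try-with-resources block: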
try (MockedStatic<RedisLock> mocked0 = mockStatic(RedisLock.class);){
...
}
try (MockedStatic<RedisLock> mocked0 = mockStatic(RedisLock.class);
MockedStatic<RedisClient> mocked1 = mockStatic(RedisClient.class);
){
...
}
Mocking a static method:
try (MockedStatic<RedisClient> mocked0 = mockStatic(RedisClient.class);){
mocked0.when(()->RedisClient.getJedis()).thenReturn(jedis);
}
Setting a private static final field:
Whitebox.setInternalState(RedisLock.class,RedisLock.class, "LOG", LOG);
Invoking a private method:
Whitebox.invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L);
Verifying a static method call:
try (MockedStatic<RedisClient> mocked0 = mockStatic(RedisClient.class);){
mocked0.verify(()->RedisClient.getJedis(), atLeast(1));
}
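The difference is already significant. If that is not convincing enough, let's look at a real example: below is an excerpt of a test from a service in one of our domains, written with native Mockito: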
@ExtendWith(MockitoExtension.class)
public class DeliveryServiceTaskTest {
MockedStatic<EnvVariable> aStatic1;
MockedStatic<ConfigUtils> aStatic2;
MockedStatic<LockUtil> aStatic3;
MockedStatic<DRMClient> aStatic4;
@BeforeEach
public void init() throws Exception {
DistributedResourceManager manager = Mockito.mock(DistributedResourceManager.class);
aStatic4.when(()->DRMClient.getInstance()).thenReturn(manager);
Whitebox.setInternalState(SpringContextHelper.class,SpringContextHelper.class,"context",context);
mockSystemProperty();
}
@AfterEach
public void after() {
aStatic1.close();
aStatic2.close();
aStatic3.close();
}
@Test
public void should_not_sendEvent_as_trylock_false() {
DeliveryServiceTask deliveryTaskSpy = Mockito.spy(deliveryServiceTask);
try (MockedStatic<LockUtil> mocked0 = mockStatic(LockUtil.class);
MockedStatic<ConfigUtils> mocked1 = mockStatic(ConfigUtils.class)) {
mocked0.when(() -> LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(true);
Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(false);
deliveryTaskSpy.start();
Mockito.verify(deliveryTaskSpy, never()).sendEvent(Mockito.any(), Mockito.any());
}
}
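This code is not only bloated but also hard to read. Every mocked static class needs a carefully chosen variable name, and every mock must be released after the test. Once a test mocks many static classes, you either open the mocks up front and close them afterwards, which makes it very easy to forget one (the example above never closes aStatic4), or you end up with a huge, ugly try-with-resources header. Readability is supposed to be the first priority of a test, yet this code is hard to even describe. Anyone familiar with our microservice code knows that a single test class may well need to mock a dozen or more static classes; imagine how bad the code gets then. To be fair, one criticism also belongs here: if a class needs a dozen static classes mocked, that alone shows how poorly the class itself is designed.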
Here is the same example rewritten with the homemade tools:
@ExtendWith({MockitoExtension.class, MockStaticExtension.class})
public class DeliveryServiceTaskTest {
@BeforeEach
public void init() throws Exception {
PowerMockito.mockStatic(EnvVariable.class,ConfigUtils.class,LockUtil.class,DRMClient.class);
DistributedResourceManager manager = Mockito.mock(DistributedResourceManager.class);
Mockito.when(DRMClient.getInstance()).thenReturn(manager);
Whitebox.setInternalState(SpringContextHelper.class,"context",context);
}
@Test
public void should_not_sendEvent_as_trylock_false() {
DeliveryServiceTask deliveryTaskSpy = Mockito.spy(deliveryServiceTask);
Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(true);
Mockito.when(LockUtil.tryLock(Mockito.any(ElasticLockWarp.class))).thenReturn(false);
deliveryTaskSpy.start();
Mockito.verify(deliveryTaskSpy, never()).sendEvent(Mockito.any(), Mockito.any());
}
}
Usage is extremely simple: annotate the test class with @ExtendWith({MockitoExtension.class, MockStaticExtension.class}), then write the test code as usual.
Another example from one of our microservices, a test that mocks many static classes at once:
@Test
public void test_gen_notify_req_should_return_not_null_when_test_data_combination() throws Exception {
try (MockedStatic<SystemLangCache> aStatic = Mockito.mockStatic(SystemLangCache.class);
MockedStatic<SpringContextUtil> aStatic1 = Mockito.mockStatic(SpringContextUtil.class);
MockedStatic<ApplicationContext> aStatic2 = Mockito.mockStatic(ApplicationContext.class);
MockedStatic<CommonLogicUtils> aStatic3 = Mockito.mockStatic(CommonLogicUtils.class);
MockedStatic<DateConvertor> aStatic4 = Mockito.mockStatic(DateConvertor.class);
MockedStatic<JacksonUtil> aStatic5 = Mockito.mockStatic(JacksonUtil.class);
MockedStatic<SubTemplateInfoFactory> aStatic6 = Mockito.mockStatic(SubTemplateInfoFactory.class);
MockedStatic<TemplateUtils> aStatic7 = Mockito.mockStatic(TemplateUtils.class)) {
// setup
Mockito.when(SystemLangCache.getDefaultLang()).thenReturn("string");
Mockito.when(SpringContextUtil.getBean("drmReader")).thenReturn(drmReader);
...
}
}
The rewritten version:
@Test
public void test_gen_notify_req_should_return_not_null_when_test_data_combination() throws Exception {
PowerMockito.mockStatic(SystemLangCache.class, SpringContextUtil.class, ApplicationContext.class,
CommonLogicUtils.class, DateConvertor.class, JacksonUtil.class, SubTemplateInfoFactory.class,
TemplateUtils.class);
// setup
Mockito.when(SystemLangCache.getDefaultLang()).thenReturn("string");
Mockito.when(SpringContextUtil.getBean("drmReader")).thenReturn(drmReader);
...
}
There is no need to open and release resources by hand, no try-with-resources clutter distorting the code structure, and no need to name each mock.
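Conceptually, mockStatic plus an after-each hook amounts to a per-thread registry of handles that are opened on demand and closed in bulk. The sketch below models that idea with JDK classes only, so it runs without Mockito; FakeMockedStatic is a hypothetical stand-in for Mockito's MockedStatic, and closeAll plays the role of the extension's after-each callback.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class StaticMockRegistryDemo {
    // Records the simple names of handles as they are closed (for demonstration only).
    public static final List<String> CLOSED = new ArrayList<>();

    // Hypothetical stand-in for Mockito's MockedStatic<?>: a handle that must be closed.
    static final class FakeMockedStatic implements AutoCloseable {
        private final Class<?> type;

        FakeMockedStatic(Class<?> type) {
            this.type = type;
        }

        @Override
        public void close() {
            CLOSED.add(type.getSimpleName());
        }
    }

    // One map per thread, created lazily on that thread's first access.
    private static final ThreadLocal<Map<Class<?>, FakeMockedStatic>> REGISTRY =
            ThreadLocal.withInitial(LinkedHashMap::new);

    // Opens one handle per class; a class that is already registered is skipped.
    public static void mockStatic(Class<?>... types) {
        Map<Class<?>, FakeMockedStatic> map = REGISTRY.get();
        for (Class<?> type : types) {
            map.computeIfAbsent(type, FakeMockedStatic::new);
        }
    }

    // What an after-each hook would call: close every handle, then forget them all.
    public static void closeAll() {
        Map<Class<?>, FakeMockedStatic> map = REGISTRY.get();
        for (FakeMockedStatic handle : map.values()) {
            handle.close();
        }
        map.clear();
    }

    public static int registeredCount() {
        return REGISTRY.get().size();
    }

    public static void main(String[] args) {
        mockStatic(String.class, Integer.class, String.class); // String is registered only once
        System.out.println(registeredCount()); // 2
        closeAll();
        System.out.println(CLOSED);            // [String, Integer]
        System.out.println(registeredCount()); // 0
    }
}
```

The test body never names a handle; it only lists classes, and cleanup happens in one place.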
Source code
The source code follows; criticism and corrections are welcome.
1 PowerMockito
Built around the mockStatic and verifyStatic functionality.
import static org.mockito.Mockito.times;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;
import java.util.HashMap;
import java.util.Map;
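public class PowerMockito {
// One registry of open static mocks per thread. withInitial gives every thread its own map;
// a static initializer would only populate the map for the thread that loaded this class.
private static final ThreadLocal<Map<Class<?>, MockedStatic<?>>> MOCKED_STATICS =
ThreadLocal.withInitial(HashMap::new);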
public static void mockStatic(Class<?> type, Class<?>... aClasses) {
doMockStatic(type);
if (aClasses != null) {
for (Class<?> aClass : aClasses) {
doMockStatic(aClass);
}
}
}
private static void doMockStatic(Class<?> type) {
assertNotNull(type, "mock class can not be null");
Map<Class<?>, MockedStatic<?>> classMockedStaticMap = MOCKED_STATICS.get();
if (!classMockedStaticMap.containsKey(type)) {
classMockedStaticMap.put(type, Mockito.mockStatic(type));
}
}
public static void closeAllStatic() {
for (Map.Entry<Class<?>, MockedStatic<?>> classMockedStaticEntry : MOCKED_STATICS.get().entrySet()) {
classMockedStaticEntry.getValue().close();
}
MOCKED_STATICS.get().clear();
}
public static MockedStatic<?> get(Class<?> aClass) {
if (!MOCKED_STATICS.get().containsKey(aClass)) {
throw new IllegalStateException("Static mock for " + aClass.getName() + " is not found");
}
return MOCKED_STATICS.get().get(aClass);
}
public static boolean contains(Class<?> aClass) {
return MOCKED_STATICS.get().containsKey(aClass);
}
public static void verifyStatic(Class<?> aClass, MockedStatic.Verification verification) {
verifyStatic(aClass, times(1), verification);
}
public static void verifyStatic(Class<?> aClass, VerificationMode verificationMode,
MockedStatic.Verification verification) {
assertNotNull(aClass, "Static mock class can not be null");
assertNotNull(verificationMode, "Static mock verificationMode can not be null");
assertNotNull(verification, "Static mock verification can not be null");
if (!MOCKED_STATICS.get().containsKey(aClass)) {
throw new IllegalStateException("Static mock for " + aClass.getName() + " is not found");
}
MOCKED_STATICS.get().get(aClass).verify(verification, verificationMode);
}
public static void closeStatic(Class<?> type, Class<?>... aClasses) {
closeStatic(type);
if (aClasses != null) {
for (Class<?> aClass : aClasses) {
// Close each additional class, not the first one again.
closeStatic(aClass);
}
}
}
private static void closeStatic(Class<?> type) {
Map<Class<?>, MockedStatic<?>> classMockedStaticMap = MOCKED_STATICS.get();
if (classMockedStaticMap.containsKey(type)) {
classMockedStaticMap.get(type).close();
classMockedStaticMap.remove(type);
}
}
private static void assertNotNull(Object object, String msg) {
if (object == null) {
throw new IllegalArgumentException(msg);
}
}
}
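A Java static initializer runs once, on whichever thread loads the class, so a ThreadLocal filled in a static block holds a value only for that thread; every other thread sees null. ThreadLocal.withInitial avoids this by creating the value lazily per thread. In practice this matters only when tests touch PowerMockito from more than one thread (for example with parallel test execution). The stdlib-only sketch below, with hypothetical field names, demonstrates the difference.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

public class ThreadLocalInitDemo {
    // Filled in a static initializer: runs once, on whichever thread loads this class.
    static final ThreadLocal<Map<Class<?>, Object>> FILLED_IN_STATIC_BLOCK = new ThreadLocal<>();

    static {
        FILLED_IN_STATIC_BLOCK.set(new HashMap<>());
    }

    // withInitial: each thread lazily gets its own map on its first get().
    static final ThreadLocal<Map<Class<?>, Object>> WITH_INITIAL = ThreadLocal.withInitial(HashMap::new);

    // Reads the thread-local from a newly started thread and reports whether it saw null.
    static boolean isNullOnFreshThread(ThreadLocal<?> threadLocal) {
        AtomicBoolean sawNull = new AtomicBoolean();
        Thread worker = new Thread(() -> sawNull.set(threadLocal.get() == null));
        worker.start();
        try {
            worker.join(); // join() also makes the worker's write visible to this thread
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return sawNull.get();
    }

    public static boolean staticBlockMapMissingOnOtherThread() {
        return isNullOnFreshThread(FILLED_IN_STATIC_BLOCK);
    }

    public static boolean withInitialMapMissingOnOtherThread() {
        return isNullOnFreshThread(WITH_INITIAL);
    }

    public static void main(String[] args) {
        System.out.println(FILLED_IN_STATIC_BLOCK.get() != null); // true: this thread loaded the class
        System.out.println(staticBlockMapMissingOnOtherThread()); // true: other threads see null
        System.out.println(withInitialMapMissingOnOtherThread()); // false: every thread gets a map
    }
}
```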
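2 Whitebox
Provides the original framework's setInternalState, getInternalState, getField and invokeMethod, with the same API. Because Java 17 blocks the usual reflection trick for changing a static final field (clearing the FINAL bit through Field's private modifiers field), this part took considerable effort. The workaround in setModifiersToField reaches into JDK internals (Class.getDeclaredFields0 and Field.modifiers), so on JDK 17 it most likely only works when the tests run with --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/java.lang.reflect=ALL-UNNAMED, and it may stop working on later JDKs.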
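To see why a workaround is needed at all, the stdlib-only sketch below (with a hypothetical Holder class) shows what plain reflection allows. Following the documented rules of java.lang.reflect.Field.set, setAccessible(true) grants write access to a final field only if it is non-static (and its class is neither hidden nor a record), so writing the static final field throws IllegalAccessException while the instance final field can be overwritten.

```java
import java.lang.reflect.Field;

public class FinalFieldReflectionDemo {
    // Hypothetical class with final fields. Neither field is a compile-time constant,
    // so reads are not inlined by the compiler.
    static class Holder {
        private static final Object STATIC_FINAL = new Object();
        private final Object instanceFinal;

        Holder(Object instanceFinal) {
            this.instanceFinal = instanceFinal;
        }
    }

    // Tries to overwrite the private static final field with plain reflection.
    public static boolean canWriteStaticFinal() {
        try {
            Field field = Holder.class.getDeclaredField("STATIC_FINAL");
            field.setAccessible(true); // succeeds, but does not grant write access to a static final
            field.set(null, new Object());
            return true;
        } catch (IllegalAccessException e) {
            return false; // documented behaviour of Field.set for static final fields
        } catch (NoSuchFieldException e) {
            throw new IllegalStateException(e);
        }
    }

    // Overwrites the private non-static final field and returns the value read back.
    public static Object writeInstanceFinal(Object newValue) {
        try {
            Holder holder = new Holder("old");
            Field field = Holder.class.getDeclaredField("instanceFinal");
            field.setAccessible(true); // for non-static finals of ordinary classes, this grants write access
            field.set(holder, newValue);
            return field.get(holder);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(canWriteStaticFinal());     // false
        System.out.println(writeInstanceFinal("new")); // new
    }
}
```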
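import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Reflection helpers that replace PowerMock's Whitebox.
*
* @since 2024-09-12
*/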
public class Whitebox {
public static final Logger LOG = LoggerFactory.getLogger(Whitebox.class);
// Maps each primitive type to its wrapper, so boxed varargs (e.g. a Long) can match primitive parameters (long).
private static final Map<Class<?>, Class<?>> PRIMITIVE_TYPE_WRAPPER =
Map.of(int.class, Integer.class, long.class, Long.class, boolean.class, Boolean.class, byte.class, Byte.class,
short.class, Short.class, char.class, Character.class, double.class, Double.class, float.class, Float.class);
public static void setInternalState(Object object, String fieldName, Object value) {
if (object == null) {
throw new IllegalArgumentException("The instance containing the field cannot be null");
}
Class<?> aClass = object instanceof Class<?> ? (Class<?>) object : object.getClass();
try {
Field field = getField(aClass, fieldName);
removeFinalModifierIfPresent(field);
field.set(object, value);
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException("Internal error: Failed to set field in method setInternalState.", e);
}
}
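public static Field getField(Class<?> declaringClass, String fieldName) throws NoSuchFieldException {
Field field = getDeclaredField(declaringClass, fieldName);
if (field == null) {
// getDeclaredField returns null when no class in the hierarchy declares the field.
throw new NoSuchFieldException(fieldName + " in " + declaringClass.getName());
}
field.setAccessible(true);
return field;
}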
public static <T> T getInternalState(Object object, String fieldName) {
try {
Field foundField = getField(getClassType(object), fieldName);
return (T) foundField.get(object);
} catch (IllegalAccessException | NoSuchFieldException e) {
throw new RuntimeException("Internal error: Failed to get field in method getInternalState.", e);
}
}
public static <T> T invokeMethod(Object tested, String methodToExecute) throws Exception {
Class<?> testedType = getClassType(tested);
Method declaredMethod = getDeclaredMethod(testedType, methodToExecute, null);
return doInvokeMethod(tested, declaredMethod, methodToExecute, null);
}
public static <T> T invokeMethod(Object tested, String methodToExecute, Object... arguments) throws Exception {
Class<?> testedType = getClassType(tested);
if (arguments == null) {
arguments = new Object[1];
arguments[0] = null;
}
Class<?>[] parameterTypes = getParameterClasses(arguments);
Method declaredMethod = getDeclaredMethod(testedType, methodToExecute, parameterTypes);
return doInvokeMethod(tested, declaredMethod, methodToExecute, arguments);
}
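private static <T> T doInvokeMethod(Object tested, Method declaredMethod, String methodToExecute,
Object... arguments) throws Exception {
if (declaredMethod == null) {
LOG.error("can not find method {}", methodToExecute);
return null;
}
declaredMethod.setAccessible(true);
try {
return (T) declaredMethod.invoke(tested, arguments);
} catch (InvocationTargetException e) {
// Rethrow what the invoked method threw; an Error (e.g. AssertionError) is not an Exception.
Throwable cause = e.getTargetException();
if (cause instanceof Exception) {
throw (Exception) cause;
}
if (cause instanceof Error) {
throw (Error) cause;
}
throw e;
}
}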
private static Class<?>[] getParameterClasses(Object[] arguments) {
if (arguments == null) {
return null;
}
Class<?>[] parameterTypes = new Class[arguments.length];
for (int i = 0; i < arguments.length; i++) {
if (arguments[i] == null) {
parameterTypes[i] = null;
continue;
}
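// Use the argument's runtime class; getClassType would wrongly unwrap an argument that is itself a Class.
parameterTypes[i] = arguments[i].getClass();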
}
return parameterTypes;
}
private static void removeFinalModifierIfPresent(Field field) throws IllegalAccessException {
int fieldModifiersMask = field.getModifiers();
if ((fieldModifiersMask & Modifier.FINAL) == Modifier.FINAL) {
setModifiersToField(field);
}
}
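private static void setModifiersToField(Field field) {
try {
// Field.modifiers is filtered out of getDeclaredField on recent JDKs, so read the raw field
// list through the internal Class.getDeclaredFields0. On JDK 17 this needs --add-opens for
// java.base/java.lang and java.base/java.lang.reflect.
Method getDeclaredFields0 = Class.class.getDeclaredMethod("getDeclaredFields0", boolean.class);
getDeclaredFields0.setAccessible(true);
Field[] fields = (Field[]) getDeclaredFields0.invoke(Field.class, false);
Field modifiersField = null;
for (Field each : fields) {
if ("modifiers".equals(each.getName())) {
modifiersField = each;
break;
}
}
if (modifiersField == null) {
throw new IllegalStateException("Field.modifiers not found; clearing FINAL is not supported on this JDK");
}
modifiersField.setAccessible(true);
// Clear the FINAL bit so that the subsequent Field.set call is permitted.
modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL);
} catch (InvocationTargetException | NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}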
private static Class<?> getClassType(Object object) {
if (object == null) {
throw new IllegalArgumentException("The object to perform the operation on cannot be null.");
}
return object instanceof Class<?> ? (Class<?>) object : object.getClass();
}
public static Method getDeclaredMethod(Class<?> aClass, String methodName, Class<?>... parameterTypes) {
for (Class<?> clazz = aClass; clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
try {
for (Method declaredMethod : clazz.getDeclaredMethods()) {
if (declaredMethod.getName().equals(methodName)
&& isParameterTypesRight(declaredMethod, parameterTypes)) {
return declaredMethod;
}
}
} catch (Exception e) {
// Nothing Should do
}
}
return null;
}
public static Field getDeclaredField(Class<?> aClass, String fieldName) {
for (Class<?> clazz = aClass; clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
try {
return clazz.getDeclaredField(fieldName);
} catch (Exception e) {
// Nothing Should do
}
}
return null;
}
private static boolean isParameterTypesRight(Method declaredMethod, Class<?>[] parameterTypes) {
Class<?>[] types = declaredMethod.getParameterTypes();
if (parameterTypes == null) {
return types == null || types.length == 0;
}
if (types == null) {
return parameterTypes == null || parameterTypes.length == 0;
}
if (parameterTypes.length != types.length) {
return false;
}
for (int i = 0; i < types.length; i++) {
if (parameterTypes[i] == null) {
continue;
}
if (parameterTypes[i] != types[i] && !types[i].isAssignableFrom(parameterTypes[i])
&& !isWrapperOfPrimitive(types[i], parameterTypes[i])) {
return false;
}
}
return true;
}
private static boolean isWrapperOfPrimitive(Class<?> primitiveClass, Class<?> wrapperClass) {
if (!PRIMITIVE_TYPE_WRAPPER.containsKey(primitiveClass)) {
return false;
}
return PRIMITIVE_TYPE_WRAPPER.get(primitiveClass).equals(wrapperClass);
}
}
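invokeMethod needs PRIMITIVE_TYPE_WRAPPER because varargs box their arguments: in invokeMethod(RedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L) the last argument arrives as a Long, and long.class.isAssignableFrom(Long.class) is false. The stdlib-only sketch below reproduces that lookup-and-invoke step under simplifying assumptions; FakeRedisLock and its private renewExpiration signature are hypothetical, and unlike the Whitebox above it refuses to pass null to a primitive parameter.

```java
import java.lang.reflect.Method;
import java.util.Map;

public class PrimitiveParamLookupDemo {
    // Hypothetical class under test with a private static method.
    public static class FakeRedisLock {
        private static String renewExpiration(String lockNo, String value, long millis) {
            return lockNo + ":" + value + ":" + millis;
        }
    }

    private static final Map<Class<?>, Class<?>> WRAPPERS = Map.of(
            int.class, Integer.class, long.class, Long.class, boolean.class, Boolean.class,
            byte.class, Byte.class, short.class, Short.class, char.class, Character.class,
            double.class, Double.class, float.class, Float.class);

    // True if arg can be passed to a parameter of type paramType via reflection.
    static boolean accepts(Class<?> paramType, Object arg) {
        if (arg == null) {
            return !paramType.isPrimitive(); // null cannot be unboxed into a primitive
        }
        return paramType.isInstance(arg) || WRAPPERS.get(paramType) == arg.getClass();
    }

    // Finds a declared method by name whose parameters accept the given (boxed) arguments.
    static Method find(Class<?> type, String name, Object... args) {
        for (Method method : type.getDeclaredMethods()) {
            Class<?>[] params = method.getParameterTypes();
            if (!method.getName().equals(name) || params.length != args.length) {
                continue;
            }
            boolean matches = true;
            for (int i = 0; i < params.length && matches; i++) {
                matches = accepts(params[i], args[i]);
            }
            if (matches) {
                return method;
            }
        }
        throw new IllegalArgumentException("No method " + name + " on " + type.getName());
    }

    public static Object invokeStatic(Class<?> type, String name, Object... args) {
        Method method = find(type, name, args);
        method.setAccessible(true);
        try {
            return method.invoke(null, args); // reflection unboxes the Long back to long
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static boolean naiveAssignableCheck() {
        return long.class.isAssignableFrom(Long.class);
    }

    public static void main(String[] args) {
        System.out.println(naiveAssignableCheck()); // false
        System.out.println(invokeStatic(FakeRedisLock.class, "renewExpiration", "lockNo101", "lv101", 30000L));
        // lockNo101:lv101:30000
    }
}
```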
3 MockStaticExtension
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
/**
* JUnit 5 extension that closes every static mock opened through PowerMockito after each test.
*
* @since 2024-10-14
*/
public class MockStaticExtension implements AfterEachCallback {
@Override
public void afterEach(ExtensionContext context) throws Exception {
PowerMockito.closeAllStatic();
}
}
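One last word: honestly, I wrote this utility out of necessity. When a class is full of static calls, it means either someone is abusing static classes and breaking object encapsulation by making static what should never have been static, or someone has let a class grow far too large and mixed, into a hodgepodge. "I have a three-foot sword that can slay every evil under heaven; I only wish it never has to leave its sheath." May we all keep good coding habits and encourage one another.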