继续上篇内容,本篇讲解三部分内容,分别是拦截器,负载均衡,以及服务降级。
一. 拦截器
定义RpcFilter 接口,提供对请求对象和响应对象的处理
/**
 * Hook for processing RPC request and response payloads.
 *
 * <p>Each method receives the current payload and returns the payload that
 * should be used from that point on. A filter can opt out of either side
 * by overriding {@link #isProvider()} or {@link #isConsumer()}.
 */
public interface RpcFilter {

    /**
     * Whether this filter applies on the provider (server) side.
     *
     * @return {@code true} by default
     */
    default boolean isProvider() {
        return true;
    }

    /**
     * Whether this filter applies on the consumer (client) side.
     *
     * @return {@code true} by default
     */
    default boolean isConsumer() {
        return true;
    }

    /**
     * Processes the request payload.
     *
     * @param params the request payload
     * @return the payload to use after filtering
     */
    Object filterRequest(Object params);

    /**
     * Processes the response payload.
     *
     * @param result the response payload
     * @return the payload to use after filtering
     */
    Object filterResponse(Object result);
}
定义一个实现类,简单实现打印请求,响应数据的功能
/**
 * Filter that logs the JSON form of every request and response payload and
 * passes the payload through unchanged.
 */
@Component
@Order(Integer.MIN_VALUE)
public class LogFilter implements RpcFilter {

    protected final Logger log = LoggerFactory.getLogger(getClass());

    /** Logs the request payload and returns it untouched. */
    @Override
    public Object filterRequest(Object params) {
        return logAndPass("请求参数:{}", params);
    }

    /** Logs the response payload and returns it untouched. */
    @Override
    public Object filterResponse(Object result) {
        return logAndPass("返回参数:{}", result);
    }

    /**
     * Writes {@code payload} as JSON at INFO level using the given message template.
     *
     * @return the same {@code payload} instance
     */
    private Object logAndPass(String template, Object payload) {
        String json = JSON.toJSONString(payload);
        log.info(template, json);
        return payload;
    }
}
调整ProviderContextStart代码,向start方法中加入initFilters逻辑。
public class ProviderContextStart implements SmartLifecycle, ApplicationContextAware {
/**
 * Starts the provider side: launches the Netty server on a new thread,
 * collects the {@code ProviderBean}s and filters from the application context,
 * and registers this instance with the naming service.
 *
 * <p>Does nothing beyond logging when no application name is configured.
 * Host-resolution and Nacos failures are logged and leave the component
 * not running.
 */
@Override
public void start() {
    log.info("准备注册服务");
    try {
        // Address the server binds to and that is advertised in the registry.
        String ip = InetAddress.getLocalHost().getHostAddress();
        rpcConfig = applicationContext.getBean(RainRpcConfig.class);
        if (StringUtils.isEmpty(rpcConfig.getApplicationName())) {
            log.info("rpc未配置应用名");
            return;
        }
        // Start the server.
        new Thread(new NettyServerTask(ip, rpcConfig.getPort())).start();
        // Collect the providers.
        Map<String, ProviderBean> beansOfType = applicationContext.getBeansOfType(ProviderBean.class);
        ProviderContext.PROVIDER_MAP.putAll(beansOfType);
        // Collect the filters.
        initFilters();
        // Register this application instance.
        namingService = NamingFactory.createNamingService(rpcConfig.getRegisterAddress());
        namingService.registerInstance(rpcConfig.getApplicationName(), ip, rpcConfig.getPort(), rpcConfig.getCluster());
        RUNNING = true;
        // Report success only when registration actually completed; logging it
        // after the catch block would also claim success on failure.
        log.info("注册服务成功");
    } catch (UnknownHostException | NacosException e) {
        log.error("注册服务异常", e);
    }
}
/**
 * Loads every {@code RpcFilter} bean that applies to the provider side and
 * appends it to {@code ProviderContext.FILTER_LIST}.
 *
 * <p>Filters are ordered by the container ({@code Ordered} / {@code @Order}).
 * A bare {@code Stream.sorted()} cannot be used here: {@code RpcFilter} is not
 * {@code Comparable}, so natural-order sorting fails with a
 * {@code ClassCastException} as soon as two filters are present.
 */
public void initFilters() {
    List<RpcFilter> filterList = applicationContext.getBeanProvider(RpcFilter.class)
            .orderedStream()
            .filter(RpcFilter::isProvider)
            .collect(Collectors.toList());
    ProviderContext.FILTER_LIST.addAll(filterList);
}
新建一个ConsumerContextStart,初始化拦截器数据
/**
 * Lifecycle component for the consumer side: on start it loads the
 * consumer-side {@code RpcFilter} beans into {@code ConsumerContext.FILTER_LIST}.
 */
public class ConsumerContextStart implements SmartLifecycle, ApplicationContextAware {

    protected final Log logger = LogFactory.getLog(getClass());

    /** Set once {@link #start()} has finished; cleared by {@link #stop()}. */
    private volatile boolean running = false;

    private ApplicationContext applicationContext;

    /** Initializes the consumer-side filters and marks this component as running. */
    @Override
    public void start() {
        logger.info("客户端准备初始化");
        initFilters();
        running = true;
        logger.info("客户端准备初始化成功");
    }

    /**
     * Loads every {@code RpcFilter} bean that applies to the consumer side and
     * appends it to {@code ConsumerContext.FILTER_LIST}.
     *
     * <p>Filters are ordered by the container ({@code Ordered} / {@code @Order}).
     * A bare {@code Stream.sorted()} cannot be used here: {@code RpcFilter} is not
     * {@code Comparable}, so natural-order sorting fails with a
     * {@code ClassCastException} as soon as two filters are present.
     */
    public void initFilters() {
        List<RpcFilter> filterList = applicationContext.getBeanProvider(RpcFilter.class)
                .orderedStream()
                .filter(RpcFilter::isConsumer)
                .collect(Collectors.toList());
        ConsumerContext.FILTER_LIST.addAll(filterList);
    }

    @Override
    public void stop() {
        running = false;
    }

    @Override
    public boolean isRunning() {
        return running;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}
调整RainRpcAutoConfiguration,注入ConsumerContextStart ,添加对拦截器的扫描路径
/**
 * Auto-configuration for the RPC framework: scans {@code com.rain.filter}
 * for components and exposes the provider and consumer lifecycle beans.
 */
@Configuration
@ComponentScan(basePackages = "com.rain.filter")
public class RainRpcAutoConfiguration {

    /** Lifecycle bean that starts the provider side. */
    @Bean
    public ProviderContextStart providerContextStart() {
        ProviderContextStart providerStart = new ProviderContextStart();
        return providerStart;
    }

    /** Lifecycle bean that initializes the consumer side. */
    @Bean
    public ConsumerContextStart consumerContextStart() {
        ConsumerContextStart consumerStart = new ConsumerContextStart();
        return consumerStart;
    }
}
最后在上下文对象 ProviderContext 和 ConsumerContext 中,添加拦截器的调用逻辑:处理请求前执行 invokeBefore,返回结果前执行 invokeAfter。
/**
 * Provider-side registry and dispatcher: holds the exported provider beans and
 * the provider-side filters, and executes incoming calls reflectively.
 */
public class ProviderContext {

    /**
     * Maps bean name to provider bean; the bean name is the fully qualified class name.
     */
    public static final Map<String, ProviderBean> PROVIDER_MAP = new ConcurrentHashMap<>();

    /** Provider-side filters, applied in list order to requests and responses. */
    public static final List<RpcFilter> FILTER_LIST = new ArrayList<>();

    /**
     * Invokes {@code methodName} on the provider registered under {@code className}.
     *
     * <p>The request filters run before the call and the response filters run on
     * the result for every arity. (Previously only the multi-argument path went
     * through {@link #invokeAfter(Object)}.)
     *
     * @param className  key of the provider in {@link #PROVIDER_MAP}
     * @param methodName name of the method to call
     * @param params     ignored for zero-argument methods, passed as the single
     *                   argument for one-argument methods, and cast to
     *                   {@code Object[]} for methods with more parameters
     * @return the filtered result of the call
     * @throws RuntimeException          if no provider is registered under {@code className}
     * @throws InvocationTargetException if the target method throws
     * @throws IllegalAccessException    if the target method is not accessible
     */
    public static Object exec(String className, String methodName, Object params) throws InvocationTargetException, IllegalAccessException {
        ProviderBean providerBean = PROVIDER_MAP.get(className);
        if (providerBean == null) {
            throw new RuntimeException("服务提供者不存在,调用类:" + className + ",调用方法" + methodName);
        }
        // Pre-processing.
        Object filteredParams = invokeBefore(params);
        Method method = providerBean.getMethod(methodName);
        Object target = providerBean.getRef();
        int parameterCount = method.getParameterTypes().length;

        Object result;
        if (parameterCount == 0) {
            result = method.invoke(target);
        } else if (parameterCount == 1) {
            result = method.invoke(target, filteredParams);
        } else {
            result = method.invoke(target, (Object[]) filteredParams);
        }
        // Post-processing, applied uniformly regardless of arity.
        return invokeAfter(result);
    }

    /**
     * Runs every filter's {@code filterRequest} in list order.
     *
     * @param params the incoming request payload
     * @return the payload produced by the last filter, or {@code params} if there are no filters
     */
    public static Object invokeBefore(Object params) {
        for (RpcFilter filter : FILTER_LIST) {
            params = filter.filterRequest(params);
        }
        return params;
    }

    /**
     * Runs every filter's {@code filterResponse} in list order.
     *
     * @param params the call result
     * @return the result produced by the last filter, or {@code params} if there are no filters
     */
    public static Object invokeAfter(Object params) {
        for (RpcFilter filter : FILTER_LIST) {
            params = filter.filterResponse(params);
        }
        return params;
    }
}
/**
 * Consumer-side context: tracks in-flight requests, the provider directory of
 * each remote application, and the consumer-side filters.
 */
public class ConsumerContext {

    /** In-flight requests keyed by request id. */
    public static final Map<Long, RainFuture> RESULT_MAP = new ConcurrentHashMap<>();

    /** Provider directories keyed by application name. */
    public static final Map<String, ProviderDirectory> CLIENT_MAP = new ConcurrentHashMap<>();

    /** Consumer-side filters, applied in list order to requests and responses. */
    public static final List<RpcFilter> FILTER_LIST = new ArrayList<>();

    /**
     * Sends {@code request} to {@code applicationName} and waits for the response.
     *
     * @param request         the request; its data is replaced by the filtered data
     * @param applicationName key of the target application in {@link #CLIENT_MAP}
     * @param requestTimeout  maximum wait for the response, in seconds
     * @return the filtered response data
     * @throws Exception the cause carried by an error response, or whatever the
     *                   send / wait raised
     */
    public static Object invoke(Request request, String applicationName, Long requestTimeout) throws Exception {
        // Pre-processing.
        request.setData(invokeBefore(request.getData()));
        RainFuture rainFuture = new RainFuture();
        RESULT_MAP.put(request.getId(), rainFuture);
        Response response;
        try {
            getChannelFuture(applicationName).channel().writeAndFlush(request);
            response = rainFuture.get(requestTimeout, TimeUnit.SECONDS);
        } finally {
            // Always drop the pending entry; otherwise a failed send or a
            // failed/timed-out wait would leave it in RESULT_MAP forever.
            RESULT_MAP.remove(request.getId());
        }
        // The provider reported an error.
        // NOTE(review): if getCode() returns a boxed Integer this comparison
        // should use equals() — confirm against ResultEnum/Response.
        if (ResultEnum.ERROR.getCode() == response.getCode()) {
            throw response.getCause();
        }
        // Post-processing.
        return invokeAfter(response.getData());
    }

    /**
     * Returns the channel future of a client for {@code applicationName},
     * polling every 50 ms until the client has one. The wait has no upper bound.
     *
     * @param applicationName key of the target application in {@link #CLIENT_MAP}
     * @return the client's channel future, never {@code null}
     * @throws RuntimeException if the application is unknown, no client is
     *                          available, or the waiting thread is interrupted
     */
    public static ChannelFuture getChannelFuture(String applicationName) {
        ProviderDirectory providerDirectory = CLIENT_MAP.get(applicationName);
        if (providerDirectory == null) {
            throw new RuntimeException("no provider directory registered for application: " + applicationName);
        }
        NettyClientTask client = providerDirectory.getClient();
        if (client == null) {
            throw new RuntimeException("client 客户端不存在");
        }
        // The future is assigned by the client's own thread, so it may not be there yet.
        while (null == client.channelFuture) {
            try {
                Thread.sleep(50);
            } catch (InterruptedException e) {
                // Restore the flag and stop waiting; continuing would make every
                // following sleep() fail immediately and spin the loop.
                Thread.currentThread().interrupt();
                throw new RuntimeException("interrupted while waiting for the client channel of " + applicationName, e);
            }
        }
        return client.channelFuture;
    }

    /**
     * Runs every filter's {@code filterRequest} in list order.
     *
     * @param params the outgoing request payload
     * @return the payload produced by the last filter, or {@code params} if there are no filters
     */
    public static Object invokeBefore(Object params) {
        for (RpcFilter filter : FILTER_LIST) {
            params = filter.filterRequest(params);
        }
        return params;
    }

    /**
     * Runs every filter's {@code filterResponse} in list order.
     *
     * @param params the response payload
     * @return the payload produced by the last filter, or {@code params} if there are no filters
     */
    public static Object invokeAfter(Object params) {
        for (RpcFilter filter : FILTER_LIST) {
            params = filter.filterResponse(params);
        }
        return params;
    }
}
二. 负载均衡
实现一个简单客户端负载功能,轮询可用的服务实例。
/**
 * Directory of the clients available for one remote application, handing them
 * out in round-robin order. Reads the {@code providerList} field, which is
 * declared outside this excerpt.
 */
public class ProviderDirectory {

    /** Monotonic call counter used to pick the next provider. */
    private final AtomicInteger seq = new AtomicInteger(0);

    /**
     * Returns the next client in round-robin order.
     *
     * @return the selected client, or {@code null} when no provider is listed
     */
    public NettyClientTask getClient() {
        // Read the field once so the emptiness check and the lookup use the
        // same list reference.
        List<ClientHolder> snapshot = this.providerList;
        if (snapshot.isEmpty()) {
            return null;
        }
        // floorMod keeps the index non-negative after the counter wraps past
        // Integer.MAX_VALUE; plain % would yield a negative index there.
        int index = Math.floorMod(seq.getAndIncrement(), snapshot.size());
        return snapshot.get(index).getClient();
    }
}
三. 服务降级
调整 ConsumerMethodInterceptor,添加降级处理:当 @RainConsumer 的 degradation 参数配置了降级处理类时,远程调用一旦抛出异常,就改为执行该降级类中的同名方法;未配置时则直接抛出原异常。
/**
 * Method interceptor for consumer proxies: turns each intercepted call into a
 * remote request and, when that fails, falls back to the degradation class
 * named by {@code @RainConsumer.degradation()} if one is configured.
 */
public class ConsumerMethodInterceptor implements MethodInterceptor {

    private final RainConsumer annotation;
    private final ApplicationContext applicationContext;

    /** Degradation bean, looked up from the application context on first use. */
    private Object degradationService;

    public ConsumerMethodInterceptor(RainConsumer annotation, ApplicationContext applicationContext) {
        this.annotation = annotation;
        this.applicationContext = applicationContext;
    }

    /**
     * Sends the call through {@code ConsumerContext.invoke}; any {@link Exception}
     * it raises is handed to {@link #degrade(Method, Object[], Exception)}.
     */
    @Override
    public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
        Request request = Request.buildRequest(method.getDeclaringClass().getName(), method.getName(), objects);
        try {
            return ConsumerContext.invoke(request, annotation.name(), annotation.timeout());
        } catch (Exception e) {
            return degrade(method, objects, e);
        }
    }

    /**
     * Runs the degradation (fallback) logic for a failed call.
     *
     * @param method  the intercepted method; the same method is invoked on the degradation bean
     * @param objects the original call arguments
     * @param e       the failure of the remote call
     * @return the value returned by the degradation bean
     * @throws Exception {@code e} itself when no degradation class is configured;
     *                   otherwise whatever loading or invoking the degradation
     *                   bean raises. An exception thrown by the fallback method
     *                   is rethrown directly, with {@code e} attached as suppressed.
     */
    public Object degrade(Method method, Object[] objects, Exception e) throws Exception {
        // No degradation configured: surface the original failure.
        if (StringUtils.isEmpty(annotation.degradation())) {
            throw e;
        }
        if (degradationService == null) {
            Class<?> degradationClass = Class.forName(annotation.degradation());
            degradationService = applicationContext.getBean(degradationClass);
        }
        try {
            return method.invoke(degradationService, objects);
        } catch (java.lang.reflect.InvocationTargetException wrapped) {
            // Unwrap so callers see the fallback's own exception instead of the
            // reflection wrapper; non-Exception causes stay wrapped.
            Throwable cause = wrapped.getCause();
            Exception failure = cause instanceof Exception ? (Exception) cause : wrapped;
            if (failure != e) {
                // Keep the remote-call failure visible in the stack trace.
                failure.addSuppressed(e);
            }
            throw failure;
        }
    }
}