Discovery
package com.test.rpc.discovery;
/**
 * Resolves a service name to the address of one provider of that service.
 */
public interface ServiceDiscovery {
/**
 * Looks up a provider address for the given service.
 *
 * @param serviceName the service name to look up; the implementation in this
 *                    code base appends it to the registry root path
 * @return the selected provider address; the caller in this code base splits
 *         it on {@code ':'} into host and port and treats {@code null} as
 *         "no provider found"
 */
String discover(String serviceName);
}
package com.test.rpc.discovery.impl;
import com.test.rpc.discovery.ServiceDiscovery;
import com.test.rpc.loadbalance.LoadBalance;
import com.test.rpc.loadbalance.impl.AbstractLoadBalance;
import com.test.rpc.registry.constants.ZKConfig;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheListener;
import org.apache.curator.retry.ExponentialBackoffRetry;
import java.util.List;
/**
 * {@link ServiceDiscovery} backed by ZooKeeper through Apache Curator.
 *
 * <p>Every {@link #discover(String)} call reads the children of
 * {@code ZKConfig.ZK_REGISTER_PATH + "/" + serviceName} and lets the configured
 * {@link LoadBalance} pick one of them. A children watcher is installed at most
 * once per path for the lifetime of this instance.
 */
public class ServiceDiscoveryImpl implements ServiceDiscovery {

    /**
     * Most recently observed provider list. Written by {@link #discover(String)}
     * and by watcher callbacks (presumably on a Curator-managed thread), hence
     * {@code volatile}. When one instance serves several service names this
     * field reflects whichever path was refreshed last.
     */
    volatile List<String> hostList = null;

    private CuratorFramework curatorFramework;

    /** Paths that already have a children watcher, so each path is watched only once. */
    private final java.util.Set<String> watchedPaths = java.util.Collections.newSetFromMap(
            new java.util.concurrent.ConcurrentHashMap<String, Boolean>());

    {
        curatorFramework = CuratorFrameworkFactory.builder()
                .connectString(ZKConfig.ZK_CONNECTION_STRING)
                .sessionTimeoutMs(4000)
                .retryPolicy(new ExponentialBackoffRetry(1000, 3))
                .build();
        curatorFramework.start();
    }

    /**
     * Reads the current providers of {@code serviceName} and selects one.
     *
     * @param serviceName name of the service node below the registry root path
     * @return the address chosen by the load balancer
     * @throws RuntimeException if the children cannot be read or the watcher cannot be started
     */
    @Override
    public String discover(String serviceName) {
        String path = ZKConfig.ZK_REGISTER_PATH + "/" + serviceName;

        // Work on a local snapshot: the shared field may be overwritten by a
        // watcher callback between the fetch and the selection below.
        List<String> hosts;
        try {
            hosts = curatorFramework.getChildren().forPath(path);
        } catch (Exception e) {
            throw new RuntimeException("获取子节点异常;", e);
        }
        hostList = hosts;

        // Track changes of the provider nodes (installed once per path).
        registerWatcher(path);

        // Load balancing; the concrete strategy is obtained from AbstractLoadBalance.
        LoadBalance loadBalance = AbstractLoadBalance.getInstants();
        String resultServerUrl = loadBalance.selectHost(hosts);
        System.out.printf("[%d]负载均衡:hostList={%s},selected={%s}\r\n", Thread.currentThread().getId(), hosts, resultServerUrl);
        return resultServerUrl;
    }

    /**
     * Installs a children watcher on {@code path} unless one is already installed.
     * Previously a new {@link PathChildrenCache} was created and started on every
     * {@code discover} call and never closed, so caches and listeners piled up.
     *
     * @param path ZooKeeper path whose children should be watched
     */
    private void registerWatcher(final String path) {
        // add() is atomic: only the first caller for a path goes on to build the cache.
        if (!watchedPaths.add(path)) {
            return;
        }
        PathChildrenCache childrenCache = new PathChildrenCache(curatorFramework, path, true);
        PathChildrenCacheListener pathChildrenCacheListener = new PathChildrenCacheListener() {
            @Override
            public void childEvent(CuratorFramework client, PathChildrenCacheEvent event) throws Exception {
                hostList = curatorFramework.getChildren().forPath(path);
            }
        };
        childrenCache.getListenable().addListener(pathChildrenCacheListener);
        try {
            childrenCache.start();
        } catch (Exception e) {
            // Allow a later discover() call to retry the registration.
            watchedPaths.remove(path);
            throw new RuntimeException("注册pathChildren wather异常", e);
        }
    }
}
proxy
package com.test.rpc.discovery.proxy;
/**
 * Factory that produces an object of a requested type; the implementation in
 * this code base returns a JDK dynamic proxy.
 */
public interface ProxyFactory {
/**
 * Creates an instance assignable to {@code clazz}.
 *
 * @param clazz the type the returned object must implement
 * @param <T>   the requested type
 * @return an object of type {@code T}
 * @throws Exception if the instance cannot be created
 */
<T> T getInstance(Class<T> clazz) throws Exception;
}
package com.test.rpc.discovery.proxy.impl;
import com.test.rpc.model.RpcRequest;
import com.test.rpc.protocol.Transport;
import com.test.rpc.protocol.impl.TcpTransportGood;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
/**
 * Invocation handler that turns every proxied call into an {@link RpcRequest}
 * and sends it over a {@link TcpTransportGood} to the fixed endpoint
 * {@code 127.0.0.1:8081}.
 *
 * @param <T> the service interface being proxied
 */
public class RpcInvocationHandler<T> implements InvocationHandler {

    /** Service interface whose name is sent as the request's class name. */
    private Class<T> clazz;

    /**
     * @param clazz service interface the proxy stands for
     * @throws Exception declared for callers; this constructor itself only stores the argument
     */
    public RpcInvocationHandler(Class<T> clazz) throws Exception {
        this.clazz = clazz;
    }

    /**
     * Packs the call into a request and returns whatever the transport answers.
     * Any {@link Exception} raised on the way is printed and answered with
     * {@code null}.
     *
     * @param proxy  the proxy instance the call was made on (unused)
     * @param method the interface method that was called
     * @param args   the call arguments, possibly {@code null}
     * @return the transport's response, or {@code null} after a failure
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            RpcRequest request = new RpcRequest();
            request.setMethodName(method.getName());
            request.setParameterTypes(method.getParameterTypes());
            request.setArgs(args);
            request.setClassName(clazz.getName());

            Transport channel = new TcpTransportGood("127.0.0.1", 8081);
            return channel.send(request);
        } catch (Exception failure) {
            failure.printStackTrace();
            return null;
        }
    }
}
package com.test.rpc.discovery.proxy.impl;
import com.test.rpc.discovery.proxy.ProxyFactory;
import java.lang.reflect.Proxy;
/**
 * {@link ProxyFactory} that builds JDK dynamic proxies whose calls are handled
 * by a {@link RpcInvocationHandler}.
 */
public class JDKProxyFactory implements ProxyFactory {

    /**
     * Creates a dynamic proxy implementing {@code clazz}, defined in the
     * current thread's context class loader.
     *
     * @param clazz interface the proxy must implement
     * @return the proxy, typed as {@code T}
     * @throws Exception if the invocation handler cannot be constructed
     */
    @Override
    public <T> T getInstance(Class<T> clazz) throws Exception {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        Class[] proxiedInterfaces = new Class[]{clazz};
        RpcInvocationHandler handler = new RpcInvocationHandler(clazz);
        Object proxy = Proxy.newProxyInstance(contextLoader, proxiedInterfaces, handler);
        return (T) proxy;
    }
}
servicelocator
package com.test.rpc.discovery.servicelocator;
import java.io.Serializable;
import java.util.Objects;
/**
 * Identifies a service by its contract name and an implementation code.
 * Two keys are equal when both parts are equal.
 */
public class ServiceKey implements Serializable {

    private static final long serialVersionUID = -2995737588215700768L;

    private String contract;
    private String implCode;

    /**
     * @param contract contract name
     * @param implCode implementation code, may be {@code null}
     */
    public ServiceKey(String contract, String implCode) {
        this.contract = contract;
        this.implCode = implCode;
    }

    /**
     * Builds the string form {@code contract/implCode}, substituting
     * {@code "default"} for a {@code null} implementation code.
     *
     * @param contract contract name
     * @param implCode implementation code, may be {@code null}
     * @return the combined key string
     */
    public static String getKeyStr(String contract, String implCode) {
        String code = implCode;
        if (code == null) {
            code = "default";
        }
        return contract + "/" + code;
    }

    public String getContract() {
        return contract;
    }

    public void setContract(String contract) {
        this.contract = contract;
    }

    public String getImplCode() {
        return implCode;
    }

    public void setImplCode(String implCode) {
        this.implCode = implCode;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("ServiceKey{");
        text.append("contract='").append(contract).append('\'');
        text.append(", implCode='").append(implCode).append('\'');
        text.append('}');
        return text.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (getClass() != o.getClass()) {
            return false;
        }
        ServiceKey other = (ServiceKey) o;
        if (!Objects.equals(contract, other.contract)) {
            return false;
        }
        return Objects.equals(implCode, other.implCode);
    }

    @Override
    public int hashCode() {
        // Same value as Objects.hash(contract, implCode).
        int hash = 1;
        hash = 31 * hash + Objects.hashCode(contract);
        hash = 31 * hash + Objects.hashCode(implCode);
        return hash;
    }
}
package com.test.rpc.discovery.servicelocator;
import com.test.rpc.discovery.ServiceDiscovery;
import com.test.rpc.discovery.impl.ServiceDiscoveryImpl;
import com.test.rpc.model.RpcRequest;
import com.test.rpc.protocol.Transport;
import com.test.rpc.protocol.impl.TcpTransportGood;
import com.test.rpc.utils.ClassUtils;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
/**
 * Creates client-side proxies that forward interface calls to a remote
 * provider located through {@link ServiceDiscovery}.
 */
public class ServiceLocator {

    /**
     * Builds a dynamic proxy for the interface named by {@code contract}.
     *
     * @param contract fully qualified name of the service interface
     * @param implCode implementation code identifying the provider variant
     * @param lazy     currently not read by this method
     * @param <T>      type the caller expects; must match {@code contract}
     * @return a proxy implementing the contract interface
     * @throws ClassNotFoundException if the contract interface cannot be loaded
     */
    public static <T> T getService(String contract, String implCode, boolean lazy) throws ClassNotFoundException {
        ServiceKey key = new ServiceKey(contract, implCode);
        return (T) Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[] { ClassUtils.forName(contract) }, new ServiceInvocation(key));
    }

    /** Invocation handler that performs the discovery lookup and the remote call. */
    private static class ServiceInvocation implements InvocationHandler {

        private final ServiceKey serviceKey;

        /**
         * Discovery client for this proxy, created on first remote call and then
         * reused. Previously a new {@link ServiceDiscoveryImpl} was built for
         * every invocation and never released.
         */
        private ServiceDiscovery discovery;

        public ServiceInvocation(ServiceKey serviceKey) {
            this.serviceKey = serviceKey;
        }

        /**
         * Answers {@code equals}/{@code hashCode}/{@code toString} locally and
         * sends every other call to a discovered provider.
         *
         * @throws Exception if no provider address is found for the service
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (Object.class == method.getDeclaringClass()) {
                return invokeObjectMethod(proxy, method, args);
            }

            RpcRequest rpcRequest = new RpcRequest();
            rpcRequest.setMethodName(method.getName());
            rpcRequest.setParameterTypes(method.getParameterTypes());
            rpcRequest.setArgs(args);
            String serviceName = serviceKey.getContract() + "#" + serviceKey.getImplCode();
            rpcRequest.setClassName(serviceName);

            String serviceAddress = discovery().discover(serviceName);
            // Without a provider address there is nothing to connect to: report and fail.
            if (serviceAddress == null) {
                System.out.printf("can not find the service by the key:%s\r\n", serviceName);
                throw new Exception(String.format("can not find the service by the key:%s\r\n", serviceName));
            }

            String[] hostAndPort = serviceAddress.split(":");
            String hostName = hostAndPort[0];
            int port = Integer.parseInt(hostAndPort[1]);
            Transport transport = new TcpTransportGood(hostName, port);
            return transport.send(rpcRequest);
        }

        /**
         * Handles the {@link Object} methods a dynamic proxy routes to its handler.
         * The original {@code method.invoke(args)} used the argument array as the
         * receiver, which fails for all three methods.
         */
        private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
            String name = method.getName();
            if ("equals".equals(name)) {
                return proxy == args[0];
            }
            if ("hashCode".equals(name)) {
                return System.identityHashCode(proxy);
            }
            if ("toString".equals(name)) {
                return "ServiceProxy[" + serviceKey + "]";
            }
            throw new UnsupportedOperationException("Unsupported Object method: " + method);
        }

        /** Returns this proxy's discovery client, creating it on first use. */
        private synchronized ServiceDiscovery discovery() {
            if (discovery == null) {
                discovery = new ServiceDiscoveryImpl();
            }
            return discovery;
        }
    }
}
publish
package com.test.rpc.publish;
import com.test.rpc.annotation.Implement;
import com.test.rpc.registry.Register;
import com.test.rpc.spring.AbstractApplicationContext;
import com.test.rpc.spring.AnnotationConfigApplicationContext;
import java.io.File;
/**
 * Publishes the services found in a package: scans for classes annotated with
 * {@link Implement}, registers an instance of each in the application context,
 * announces them at the register center and then starts the socket listener.
 */
public class PublisherWithZK {

    private static final String CLASS_FILE_SUFFIX = ".class";

    public static AbstractApplicationContext applicationContext = new AnnotationConfigApplicationContext();

    /** Register center the services are announced to. */
    private final Register registerCenter;

    /** Address the services are published at, in {@code ip:port} form. */
    private final String serviceAddress;

    /**
     * @param registerCenter register center to announce services to
     * @param serviceAddress publish address in {@code ip:port} form
     */
    public PublisherWithZK(Register registerCenter, String serviceAddress) {
        this.registerCenter = registerCenter;
        this.serviceAddress = serviceAddress;
    }

    /**
     * Exposes all services below {@code scanPackage} and starts listening.
     * A failure while starting the listener is printed, not rethrown.
     *
     * @param scanPackage root package to scan, e.g. {@code com.example.service}
     * @throws IllegalAccessException if a service class cannot be instantiated
     * @throws InstantiationException if a service class cannot be instantiated
     */
    public void export(String scanPackage) throws IllegalAccessException, InstantiationException {
        // 1. Local exposure
        localExport(scanPackage);
        // 2. Remote exposure
        remoteExport();
        // 3. Start listening
        try {
            new ServerListener(parsePort()).start();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Scans the package and registers an instance of every annotated class. */
    private void localExport(String scanPackage) throws InstantiationException, IllegalAccessException {
        doScan(scanPackage);
        doRegister();
    }

    /** Instantiates each scanned class and stores it under {@code contractName#impCode}. */
    private void doRegister() throws IllegalAccessException, InstantiationException {
        for (Class<?> bean : applicationContext.beanList) {
            Implement implement = bean.getAnnotation(Implement.class);
            applicationContext.registerBean(implement.contract().getName() + "#" + implement.impCode(), bean.newInstance());
        }
    }

    /**
     * Recursively walks the directory backing {@code scanPackage} and records
     * every class annotated with {@link Implement}.
     *
     * @throws IllegalArgumentException if the package is not found by the context class loader
     */
    private void doScan(String scanPackage) {
        java.net.URL resource = Thread.currentThread().getContextClassLoader()
                .getResource(scanPackage.replace('.', '/'));
        if (resource == null) {
            throw new IllegalArgumentException("Package not found on the classpath: " + scanPackage);
        }
        File directory = new File(resource.getFile());
        if (!directory.exists() || !directory.isDirectory()) {
            return;
        }
        File[] entries = directory.listFiles();
        if (entries == null) {
            // listFiles() returns null when the directory cannot be read.
            return;
        }
        for (File entry : entries) {
            String entryName = entry.getName();
            if (entry.isDirectory()) {
                doScan(scanPackage + "." + entryName);
                continue;
            }
            if (!entryName.endsWith(CLASS_FILE_SUFFIX)) {
                // Resources and other non-class files cannot be loaded as classes.
                continue;
            }
            // Strip the literal suffix; a regex ".class" would also eat e.g. "yclass".
            String simpleName = entryName.substring(0, entryName.length() - CLASS_FILE_SUFFIX.length());
            String className = scanPackage + "." + simpleName;
            try {
                Class clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(Implement.class)) {
                    applicationContext.registerBean(clazz);
                }
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    /** Announces every locally registered service at the register center. */
    private void remoteExport() {
        // Fail before registering anything if the address has no usable port.
        parsePort();
        for (String serviceName : applicationContext.getIocKeys()) {
            registerCenter.register(serviceName, serviceAddress);
            System.out.printf("注册服务成功:%s->%s", serviceName, serviceAddress);
        }
    }

    /** Extracts the port from the {@code ip:port} service address. */
    private int parsePort() {
        String[] addrs = this.serviceAddress.split(":");
        return Integer.valueOf(addrs[1]);
    }
}
package com.test.rpc.publish;
import com.test.rpc.model.RpcRequest;
import com.test.rpc.utils.ReflectUtil;
import java.io.*;
import java.net.Socket;
/**
 * Task that serves one request on an accepted socket: reads an
 * {@link RpcRequest}, invokes it through {@link ReflectUtil} and writes the
 * result back.
 */
public class SocketHandlerThread implements Runnable {

    private Socket socket;

    /**
     * @param socket accepted client connection to serve
     */
    public SocketHandlerThread(Socket socket) {
        this.socket = socket;
    }

    @Override
    public void run() {
        handleSocket(this.socket);
    }

    /**
     * Handles a single request/response exchange. I/O and deserialization
     * failures are printed and otherwise ignored.
     *
     * @param socket connection to read the request from and write the response to
     */
    private void handleSocket(Socket socket) {
        // The object output stream is opened before the object input stream,
        // exactly as before.
        try (OutputStream rawOut = socket.getOutputStream();
             ObjectOutputStream responseOut = new ObjectOutputStream(rawOut);
             InputStream rawIn = socket.getInputStream();
             ObjectInputStream requestIn = new ObjectInputStream(rawIn)) {
            // 1. Read what the client sent (class name, method name, parameter types, arguments)
            RpcRequest request = (RpcRequest) requestIn.readObject();
            // 2. Invoke reflectively
            Object answer = ReflectUtil.invoke(request);
            // 3. Send the result back
            responseOut.writeObject(answer);
            responseOut.flush();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
}
package com.test.rpc.publish;
import com.test.rpc.model.RpcRequest;
import com.test.rpc.utils.ReflectUtil;
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class ServerListener {
private int port;
ServerSocket serverSocket;
ExecutorService executor = null;
public ServerListener(int port) throws Exception {
this.port = port;
serverSocket = new ServerSocket(port);
executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
System.out.println("server init success ,listen on port:" + port);
}
public void start() throws Exception {
do {
Socket socket = serverSocket.accept();
if (null != socket) {
// new Thread(new SocketHandlerThread(socket)).start();
executor.execute(new SocketHandlerThread(socket));
// handleSocket(socket);
}
} while (true);
}
private void handleSocket(Socket socket) {
try (OutputStream outputStream = socket.getOutputStream();
ObjectOutputStream outputStream1 = new ObjectOutputStream(outputStream);
//1.取到客户端发过来信息(类名,方法名,方法参数类型,参数值)
InputStream inputStream = socket.getInputStream();
ObjectInputStream objectInputStream = new ObjectInputStream(inputStream);) {
RpcRequest rpcRequest = (RpcRequest) objectInputStream.readObject();
//2.反射调用
Object response = ReflectUtil.invoke(rpcRequest);
//3.返回结果
outputStream1.writeObject(response);
outputStream1.flush();
} catch (IOException e) {
e.printStackTrace();
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
package com.test.rpc.publish;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Minimal publisher: registers one hard-wired service implementation and
 * starts a {@link ServerListener} on port 8083.
 */
public class Publisher {

    /** Maps a service name (class or interface name) to its implementation instance. */
    public static Map<String,Object> beanFactory = new ConcurrentHashMap<>();

    /**
     * Instantiates the demo service, stores it under its interface name and
     * starts listening.
     *
     * @throws Exception if the implementation cannot be loaded or instantiated,
     *                   or the listener fails
     */
    public static void export() throws Exception {
        // Publish the service
        Class implType = Class.forName("com.test.producer.service.impl.DemoServiceImpl");
        Object implInstance = implType.newInstance();
        beanFactory.put("com.test.api.service.DemoService", implInstance);

        // Start listening
        ServerListener listener = new ServerListener(8083);
        listener.start();
    }
}
util
package com.test.rpc.utils;
import java.lang.reflect.Array;
import java.util.*;
/**
 * Class-loading helpers: resolving class names (including primitive and array
 * notations) and collecting the interfaces implemented by a class.
 */
public class ClassUtils {

    /** Suffix of source-style array names such as {@code java.lang.String[]}. */
    public static final String ARRAY_SUFFIX = "[]";

    /** Prefix of JVM-internal object array names such as {@code [Ljava.lang.String;}. */
    private static final String INTERNAL_ARRAY_PREFIX = "[L";

    /** Wrapper type to primitive type, e.g. {@code Integer.class -> int.class}. */
    private static final Map<Class<?>, Class<?>> primitiveWrapperTypeMap = new HashMap<Class<?>, Class<?>>(8);

    /** {@code Class.getName()} of each primitive and primitive array type to that type. */
    private static final Map<String, Class<?>> primitiveTypeNameMap = new HashMap<String, Class<?>>(16);

    static {
        primitiveWrapperTypeMap.put(Boolean.class, Boolean.TYPE);
        primitiveWrapperTypeMap.put(Byte.class, Byte.TYPE);
        primitiveWrapperTypeMap.put(Character.class, Character.TYPE);
        primitiveWrapperTypeMap.put(Double.class, Double.TYPE);
        primitiveWrapperTypeMap.put(Float.class, Float.TYPE);
        primitiveWrapperTypeMap.put(Integer.class, Integer.TYPE);
        primitiveWrapperTypeMap.put(Long.class, Long.TYPE);
        primitiveWrapperTypeMap.put(Short.class, Short.TYPE);

        Set<Class<?>> primitiveTypes = new HashSet<Class<?>>(16);
        primitiveTypes.addAll(primitiveWrapperTypeMap.values());
        primitiveTypes.addAll(Arrays.<Class<?>>asList(
                boolean[].class, byte[].class, char[].class, double[].class,
                float[].class, int[].class, long[].class, short[].class));
        for (Class<?> primitiveType : primitiveTypes) {
            primitiveTypeNameMap.put(primitiveType.getName(), primitiveType);
        }
    }

    public ClassUtils() {
    }

    /**
     * Returns the thread context class loader, or the loader of this class
     * when the context loader is unavailable.
     */
    public static ClassLoader getDefaultClassLoader() {
        ClassLoader cl = null;
        try {
            cl = Thread.currentThread().getContextClassLoader();
        } catch (Throwable ignored) {
            // Cannot access the context loader; fall back to this class's loader.
        }
        if (cl == null) {
            cl = ClassUtils.class.getClassLoader();
        }
        return cl;
    }

    /**
     * Replaces the thread context class loader when {@code classLoaderToUse}
     * is non-null and differs from the current one.
     *
     * @return the previous context loader if it was replaced, otherwise {@code null}
     */
    public static ClassLoader overrideThreadContextClassLoader(ClassLoader classLoaderToUse) {
        Thread currentThread = Thread.currentThread();
        ClassLoader threadContextClassLoader = currentThread.getContextClassLoader();
        if (classLoaderToUse != null && !classLoaderToUse.equals(threadContextClassLoader)) {
            currentThread.setContextClassLoader(classLoaderToUse);
            return threadContextClassLoader;
        }
        return null;
    }

    /** Same as {@link #forName(String, ClassLoader)} using {@link #getDefaultClassLoader()}. */
    public static Class forName(String name) throws ClassNotFoundException, LinkageError {
        return forName(name, getDefaultClassLoader());
    }

    /**
     * Resolves a class name. Besides ordinary class names it accepts primitive
     * names ({@code int}), source-style arrays ({@code java.lang.String[]}) and
     * JVM-internal array names ({@code [I}, {@code [[I},
     * {@code [Ljava.lang.String;}, {@code [[Ljava.lang.String;}).
     *
     * @param name        the name to resolve, must not be {@code null}
     * @param classLoader loader to use, or {@code null} for the default loader
     * @return the resolved class
     * @throws ClassNotFoundException if the name cannot be resolved; this now
     *         includes malformed array names, which used to fail with a
     *         {@code NullPointerException}
     */
    public static Class forName(String name, ClassLoader classLoader) throws ClassNotFoundException, LinkageError {
        if (name == null) {
            throw new NullPointerException("Class name must not be null");
        }
        Class primitive = resolvePrimitiveClassName(name);
        if (primitive != null) {
            return primitive;
        }
        // Source style: "java.lang.String[]"
        if (name.endsWith(ARRAY_SUFFIX)) {
            String elementClassName = name.substring(0, name.length() - ARRAY_SUFFIX.length());
            return arrayClassOf(forName(elementClassName, classLoader));
        }
        // Multi-dimensional internal form: peel off one dimension and recurse.
        if (name.startsWith("[[")) {
            return arrayClassOf(forName(name.substring(1), classLoader));
        }
        // One-dimensional internal form: "[Ljava.lang.String;"
        if (name.startsWith(INTERNAL_ARRAY_PREFIX) && name.endsWith(";")) {
            String elementClassName = name.substring(INTERNAL_ARRAY_PREFIX.length(), name.length() - 1);
            return arrayClassOf(forName(elementClassName, classLoader));
        }
        ClassLoader classLoaderToUse = (classLoader != null) ? classLoader : getDefaultClassLoader();
        return classLoaderToUse.loadClass(name);
    }

    /** Returns the array type whose component type is {@code elementClass}. */
    private static Class arrayClassOf(Class elementClass) {
        return Array.newInstance(elementClass, 0).getClass();
    }

    /**
     * Like {@link #forName(String, ClassLoader)} but reports failures as
     * {@link IllegalArgumentException} with the original problem as cause.
     */
    public static Class resolveClassName(String className, ClassLoader classLoader) throws IllegalArgumentException {
        try {
            return forName(className, classLoader);
        } catch (ClassNotFoundException ex) {
            IllegalArgumentException iae = new IllegalArgumentException("Cannot find class [" + className + "]");
            iae.initCause(ex);
            throw iae;
        } catch (LinkageError err) {
            IllegalArgumentException iae = new IllegalArgumentException("Error loading class [" + className + "]: problem with class file or dependent class.");
            iae.initCause(err);
            throw iae;
        }
    }

    /**
     * Looks up a primitive or primitive-array type by its {@code Class.getName()}.
     *
     * @return the type, or {@code null} when the name is {@code null}, longer
     *         than 8 characters or not a known primitive name
     */
    public static Class resolvePrimitiveClassName(String name) {
        Class result = null;
        // Every key in the map is at most 8 characters long, so skip longer names.
        if (name != null && name.length() <= 8) {
            result = primitiveTypeNameMap.get(name);
        }
        return result;
    }

    /** Returns all interfaces implemented by the class of {@code instance}. */
    public static Class[] getAllInterfaces(Object instance) {
        return getAllInterfacesForClass(instance.getClass());
    }

    /** Returns all interfaces implemented by {@code clazz}, without a visibility check. */
    public static Class<?>[] getAllInterfacesForClass(Class<?> clazz) {
        return getAllInterfacesForClass(clazz, (ClassLoader) null);
    }

    /** Array form of {@link #getAllInterfacesForClassAsSet(Class, ClassLoader)}. */
    public static Class<?>[] getAllInterfacesForClass(Class<?> clazz, ClassLoader classLoader) {
        Set<Class> ifcs = getAllInterfacesForClassAsSet(clazz, classLoader);
        return ifcs.toArray(new Class[ifcs.size()]);
    }

    /** Returns all interfaces implemented by the class of {@code instance} as a set. */
    public static Set<Class> getAllInterfacesAsSet(Object instance) {
        return getAllInterfacesForClassAsSet(instance.getClass());
    }

    /** Returns all interfaces implemented by {@code clazz} as a set, without a visibility check. */
    public static Set<Class> getAllInterfacesForClassAsSet(Class clazz) {
        return getAllInterfacesForClassAsSet(clazz, (ClassLoader) null);
    }

    /**
     * Collects the interfaces of {@code clazz} and its superclasses in encounter
     * order. An interface that is visible to {@code classLoader} is returned as
     * a singleton set containing only itself.
     *
     * @param clazz       class or interface to inspect
     * @param classLoader loader the interfaces must be visible to, or {@code null} to accept all
     */
    public static Set<Class> getAllInterfacesForClassAsSet(Class clazz, ClassLoader classLoader) {
        if (clazz.isInterface() && isVisible(clazz, classLoader)) {
            return Collections.singleton(clazz);
        }
        Set<Class> interfaces = new LinkedHashSet<Class>();
        for (Class current = clazz; current != null; current = current.getSuperclass()) {
            for (Class ifc : current.getInterfaces()) {
                interfaces.addAll(getAllInterfacesForClassAsSet(ifc, classLoader));
            }
        }
        return interfaces;
    }

    /**
     * Tells whether {@code classLoader} resolves the name of {@code clazz} to
     * that very class object. A {@code null} loader counts as visible.
     */
    public static boolean isVisible(Class<?> clazz, ClassLoader classLoader) {
        if (classLoader == null) {
            return true;
        }
        try {
            Class<?> actualClass = classLoader.loadClass(clazz.getName());
            return clazz == actualClass;
        } catch (ClassNotFoundException ex) {
            return false;
        }
    }
}
package com.test.rpc.utils;
import com.test.rpc.model.RpcRequest;
import com.test.rpc.publish.PublisherWithZK;
import java.lang.reflect.Method;
/**
 * Server-side dispatcher that executes an {@link RpcRequest} against a bean
 * held by {@code PublisherWithZK.applicationContext}.
 */
public class ReflectUtil {

    /**
     * Looks up the bean named by the request's class name, finds the method
     * declared on the bean's class with the requested name and parameter types,
     * and invokes it with the request's arguments.
     *
     * @param rpcRequest the call to execute
     * @return the method's return value, or {@code null} when any exception
     *         occurs (the exception is printed)
     */
    public static Object invoke(RpcRequest rpcRequest) {
        try {
            Object bean = PublisherWithZK.applicationContext.getBean(rpcRequest.getClassName());
            Class beanType = bean.getClass();
            Method target = beanType.getDeclaredMethod(rpcRequest.getMethodName(), rpcRequest.getParameterTypes());
            return target.invoke(bean, rpcRequest.getArgs());
        } catch (Exception failure) {
            failure.printStackTrace();
            return null;
        }
    }
}