RPC是远程过程调用,对于java而言,就是两个JVM通信,一个JVM a想要调用另一个JVM b中的类。b把执行结果在发送给a的过程。好,我们就是要来实现这个过程。
两个接口:
public interface IDiff {
double diff(double a,double b);
}
public interface ISum {
public int sum(int a, int b);
}
两个实现:
public class DiffImpl implements IDiff{
@Override
public double diff(double a, double b) {
return a - b;
}
}
public class SumImpl implements ISum {
public int sum(int a, int b) {
return a + b;
}
}
我们假设这两个类在服务器上,客户端没有这两个类,但客户端还是想使用这两个服务,所以我们可以通过网络,把客户端的想法,告诉服务器,我要使用那两个类。而在这之前,服务器先要开启这个服务。
服务端:
public class RPCServer {
private static final Logger LOG = LoggerFactory.getLogger(RPCServer.class);
private static final int threadSize = 10;
private final ExecutorService threadPool;
/**
*服务在一开始就是确定的
*不允许客户端自动添加服务
*Key为全限定接口名,Value为接口实现类对象
*/
private final Map<String, Object> servicePool;
private final int port;
private volatile boolean stop;
public RPCServer(Map<String, Object> servicePool, int port) {
this.port = port;
this.threadPool = Executors.newFixedThreadPool(threadSize);
this.servicePool = servicePool;
}
/**
* RPC服务端处理函数 监听指定TPC端口
* 每次有请求过来的时候调用服务,放入线程池中处理.
*/
@SuppressWarnings("resource")
public void service() throws IOException {
ServerSocket serverSocket = new ServerSocket(port);
while (!stop) {
final Socket socket = serverSocket.accept();
threadPool.execute(new Runnable() {
public void run() {
try {
process(socket);
} catch (Exception e) {
LOG.warn("Illegal calls",e);
} finally {
IOUtil.close(socket);
}
}
});
}
}
public void stop() {
stop = true;
threadPool.shutdown();
}
/**
* 调用服务 通过TCP Socket返回结果对象
* 有可能因为客户端的
* 类名错误
* 方法名错误
* 参数错误
* 而调用失败
*/
private void process(Socket socket) throws Exception {
ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
Message message = (Message) in.readObject();
// 调用服务
Object result = call(message);
if(result == null){
LOG.warn("Without this service, the service interface is "+message.getInterfaceName());
IOUtil.close(socket);
return;
}
ObjectOutputStream out = new ObjectOutputStream(socket.getOutputStream());
out.writeObject(result);
IOUtil.close(socket);
}
/**
* 服务处理函数 通过包名+接口名
* 在servicePool中找到对应服务
* 通过调用方法参数类型数组获取Method对象
* 通过Method.invoke(对象,参数)调用对应服务
*/
private Object call(Message message) throws Exception {
String interfaceName = message.getInterfaceName();
Object service = servicePool.get(interfaceName);
if (service == null) {
return null;
}
Class<?> serviceClass = Class.forName(interfaceName);
Method method = serviceClass.getMethod(message.getMethodName(), message.getParamsTypes());
Object result = method.invoke(service, message.getParameters());
return result;
}
}
开启服务:
public static void main(String[] args){
Map<String,Object> servicePool = new HashMap<String, Object>();
// 先将服务确定好,才能调用,不允许客户端自动添加服务
servicePool.put(ISum.class.getName(), new SumImpl());
servicePool.put(IDiff.class.getName(), new DiffImpl());
RPCServer server = new RPCServer(servicePool,8080);
try {
server.service();
} catch (IOException e) {
e.printStackTrace();
}
}
这是一个通用的消息格式,用于在client和server之间传递消息,
消息主要包括;客户端需要调用的接口,方法名,参数类型
以及参数数组,这样是为了唯一确定调用的是哪个类的哪个方法,参数是什么,以及什么类型。就是通过反射调用的。
/**
* RPC调用条件
* 1.调用接口名称 (包名+接口名)
* 2.调用方法名
* 3.调用参数Class类型数组
* 4.调用接口的参数数组
*/
public class Message implements Serializable {
private static final long serialVersionUID = 1L;
// 包名+接口名称
private String interfaceName;
private String methodName;
private Class<?>[] paramsTypes;
private Object[] parameters;
public Message() {
}
public Message(String interfaceName, String methodName,
Class<?>[] paramsTypes, Object[] parameters) {
this.interfaceName = interfaceName;
this.methodName = methodName;
this.paramsTypes = paramsTypes;
this.parameters = parameters;
}
//setters and getters
}
上面是服务端的代码,主要就是使用接口名作为key,实现类作为值;
使用反射调用方法,再把结果用java序列化的方式写会客户端。
好了,我们再来看看客户端的实现。
/**
* @author root
*客户端比较简单,就是连接服务器
*然后用java序列化,把对象发送给服务器
*/
public class RPCClient {
// 服务端地址
private final String serverAddress;
// 服务端端口
private final int serverPort;
public RPCClient(String serverAddress, int serverPort) {
this.serverAddress = serverAddress;
this.serverPort = serverPort;
}
/**
* 同步的请求和接收结果
*/
public Object sendAndReceive(Message transportMessage) {
Object result = null;
Socket socket = null;
try {
socket = new Socket(serverAddress, serverPort);
// 反序列化 TransportMessage对象
ObjectOutputStream out = new ObjectOutputStream(socket.getOutputStream());
out.writeObject(transportMessage);
ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
// 阻塞等待读取结果并反序列化结果对象
result = in.readObject();
} catch (Exception e) {
e.printStackTrace();
} finally {
IOUtil.close(socket);
}
return result;
}
public String getServerAddress() {
return serverAddress;
}
public int getServerPort() {
return serverPort;
}
}
客户端请求服务器:
public class TestClient {
public static void main(String[] args) {
String serverAddress = "localhost";
int serverPort = 8080;
int count = 10;
final RPCClient client = new RPCClient(serverAddress, serverPort);
final Random random = new Random();
ExecutorService exe = Executors.newFixedThreadPool(count );
for (int i = 0; i < count; i++) {
exe.execute(new Runnable() {
public void run() {
Message transportMessage = null;
if(random.nextBoolean())
transportMessage = buildMessage();
else
transportMessage = buildMessage2();
Object result = client.sendAndReceive(transportMessage);
System.out.println(result);
}
});
}
exe.shutdown();
}
/**
* 创建一次消息调用
* @return
*/
public static Message buildMessage() {
String interfaceName = ISum.class.getName();
Class<?>[] paramsTypes = { int.class, int.class};
Object[] parameters = { 9, 3};
String methodName = "sum";
Message transportMessage = new Message(interfaceName,
methodName, paramsTypes, parameters);
return transportMessage;
}
public static Message buildMessage2() {
String interfaceName = IDiff.class.getName();
Class<?>[] paramsTypes = { double.class, double.class};
Object[] parameters = { 9.0, 3.0};
String methodName = "diff";
Message transportMessage = new Message(interfaceName,
methodName, paramsTypes, parameters);
return transportMessage;
}
}
大家可以看到客户端其实很简单:就是把消息发给服务器,然后接受调用结果。好了这就是一个rpc调用的过程,但是这个效率很低,低的原因主要是java序列化和阻塞的IO模型。在高并发的场景下,这种线程阻塞的方式,如果使用能适应高并发高效的NIO和更好的java序列化框架(比如protobuf),效果会更好。
先开启服务器,在开启客户端,这是一次客户端打印的结果: