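一、代码结构如图。涉及的文件:两个proto文件(CalculateMessage.proto、CalculateServer.proto),以及org.example.YarnRpcTest包下的Calculate、CalculatePB、Server、CalculateService、CalculatePBServiceImpl和Client。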
二、 实例编程步骤及代码
(1) 定义Calculate接口,提供加法和减法计算
package org.example.YarnRpcTest;
/**
 * Plain integer-arithmetic contract of the example.
 *
 * <p>Implemented locally by CalculateService (which does the arithmetic) and by
 * Client (which forwards each call over the socket RPC).
 */
public interface Calculate {

    /** Adds the two operands; see implementations for the exact semantics. */
    int add(int num1, int num2);

    /** Subtracts {@code num2} from {@code num1}; see implementations for the exact semantics. */
    int minus(int num1, int num2);
}
(2) 定义两个proto文件:CalculateMessage.proto定义请求和响应消息(RequestProto、ResponseProto),CalculateServer.proto定义RPC服务CalculateService。定义好后,通过protoc编译器将两个文件分别编译为calculateproto包下的CalculateMessage类和Calculate类(类名由java_outer_classname选项指定)
// Declared explicitly: the messages below use proto2-only `required` fields,
// and protoc warns when the syntax statement is omitted.
syntax = "proto2";

// CalculateMessage.proto: request/response messages shared by every RPC.
// protoc generates them as nested classes of calculateproto.CalculateMessage
// (java_package + java_outer_classname).
option java_package="calculateproto";
option java_outer_classname="CalculateMessage";
option java_generic_services=true;
option java_generate_equals_and_hash=true;

// Request for any CalculateService method. The server resolves the method to
// invoke from methodName, so it must equal an rpc name ("add" or "minus").
message RequestProto{
  required string methodName=1;
  required int32 num1=2;
  required int32 num2=3;
}

// Reply carrying the integer result of the invoked operation.
message ResponseProto{
  required int32 result=1;
}
// Declared explicitly so protoc does not warn about an implicit proto2 default;
// must be the first statement of the file.
syntax = "proto2";

// CalculateServer.proto: declares the RPC service. protoc compiles it into the
// outer class calculateproto.Calculate; with java_generic_services enabled the
// nested CalculateService class provides BlockingInterface and
// newReflectiveBlockingService, which the Java code relies on.
import "CalculateMessage.proto";

option java_package="calculateproto";
option java_outer_classname="Calculate";
option java_generic_services=true;
option java_generate_equals_and_hash=true;

service CalculateService{
  // Both RPCs share one request type; the server dispatches on RequestProto.methodName.
  rpc add(RequestProto) returns (ResponseProto);
  rpc minus(RequestProto) returns (ResponseProto);
}
(3) 定义CalculatePB接口,它继承protoc生成的Calculate.CalculateService.BlockingInterface:
package org.example.YarnRpcTest;
import calculateproto.Calculate;
public interface CalculatePB extends Calculate.CalculateService.BlockingInterface {
}
(4) 服务端程序Server
package org.example.YarnRpcTest;
import calculateproto.CalculateMessage.RequestProto;
import com.google.protobuf.BlockingService;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Message;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Arrays;
/**
 * Minimal single-threaded RPC server built on a plain {@link ServerSocket}.
 *
 * <p>Wire format in both directions: a 4-byte length written with
 * {@link DataOutputStream#writeInt(int)}, followed by that many bytes of a serialized
 * protobuf message. Each connection carries exactly one request/response pair. After
 * serving a fixed number of connections, the thread closes the listener and exits.
 */
public class Server extends Thread {
    /** Connections served by the three-argument constructor before shutting down. */
    private static final int DEFAULT_MAX_REQUESTS = 10;

    private final Class<?> protocol; // kept for reference; not read by this class
    private final BlockingService impl; // generated service that executes the calls
    private final int port;
    private final int maxRequests;
    private ServerSocket ss;

    /**
     * Creates a server that handles {@value #DEFAULT_MAX_REQUESTS} connections.
     *
     * @param protocol     protocol interface being served
     * @param protocolImpl protobuf {@link BlockingService} that executes requests
     * @param port         TCP port to listen on
     */
    public Server(Class<?> protocol, BlockingService protocolImpl, int port) {
        this(protocol, protocolImpl, port, DEFAULT_MAX_REQUESTS);
    }

    /**
     * Creates a server that handles {@code maxRequests} connections and then stops.
     *
     * @param protocol     protocol interface being served
     * @param protocolImpl protobuf {@link BlockingService} that executes requests
     * @param port         TCP port to listen on
     * @param maxRequests  number of connections to accept before closing the listener
     */
    public Server(Class<?> protocol, BlockingService protocolImpl, int port, int maxRequests) {
        this.protocol = protocol;
        this.impl = protocolImpl;
        this.port = port;
        this.maxRequests = maxRequests;
    }

    /**
     * Opens the listening socket and serves one request per accepted connection.
     * A failure on one connection is logged and does not stop the loop. If the listener
     * cannot be opened, the error is reported and the thread ends.
     */
    @Override
    public void run() {
        try (ServerSocket serverSocket = new ServerSocket(port)) {
            ss = serverSocket;
            for (int served = 0; served < maxRequests; served++) {
                // Each client socket and its streams are closed after the reply is sent,
                // so connections no longer leak across iterations.
                try (Socket clientSocket = serverSocket.accept();
                     DataOutputStream dos = new DataOutputStream(clientSocket.getOutputStream());
                     DataInputStream dis = new DataInputStream(clientSocket.getInputStream())) {
                    serveOne(dis, dos);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        } catch (IOException e) {
            System.err.println("Server: I/O error on port " + port + ": " + e);
            return;
        }
        System.out.println("socket close!");
    }

    /** Reads one length-prefixed request, executes it and writes the length-prefixed reply. */
    private void serveOne(DataInputStream dis, DataOutputStream dos) throws Exception {
        int dataLen = dis.readInt();
        if (dataLen < 0) {
            throw new IOException("negative request length: " + dataLen);
        }
        byte[] dataBuffer = new byte[dataLen];
        // readFully blocks until the whole frame has arrived; a single read() may return early.
        dis.readFully(dataBuffer);
        byte[] result = processRpc(dataBuffer);
        System.out.println("run:result=" + Arrays.toString(result));
        dos.writeInt(result.length);
        dos.write(result);
        dos.flush();
    }

    /**
     * Decodes a {@link RequestProto} and invokes the method it names on the service.
     *
     * @param data serialized {@code RequestProto}
     * @return serialized response message
     * @throws IllegalArgumentException if the service has no method called
     *                                  {@code request.getMethodName()}
     * @throws Exception                on parse failure or when the invoked method fails
     */
    public byte[] processRpc(byte[] data) throws Exception {
        RequestProto request = RequestProto.parseFrom(data);
        String methodName = request.getMethodName();
        MethodDescriptor methodDescriptor =
                impl.getDescriptorForType().findMethodByName(methodName);
        if (methodDescriptor == null) {
            throw new IllegalArgumentException("unknown RPC method: " + methodName);
        }
        // No RpcController is supplied; the PB implementation in this example ignores it.
        Message response = impl.callBlockingMethod(methodDescriptor, null, request);
        System.out.println("response:" + response.toString());
        return response.toByteArray();
    }
}
(5) 服务端主程序CalculateService
package org.example.YarnRpcTest;
import com.google.protobuf.BlockingService;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
/**
 * Server-side entry point: the plain {@link Calculate} implementation plus the
 * reflective wiring that exposes it through the protobuf RPC {@link Server}.
 *
 * <p>Class names are derived from {@link #protocol} by convention:
 * {@code <package>.<Simple>PBServiceImpl} for the PB adapter, and
 * {@code calculateproto.<Simple>$<Simple>Service} for the protoc-generated service.
 */
public class CalculateService implements Calculate {
    private Server server = null;
    private final Class<?> protocol = Calculate.class;
    private final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
    private final String protoPackage = "calculateproto";
    private final int port = 8888;

    public CalculateService() {
    }

    @Override
    public int add(int num1, int num2) {
        return num1 + num2;
    }

    @Override
    public int minus(int num1, int num2) {
        return num1 - num2;
    }

    /**
     * Loads the PB adapter class for {@link #protocol}, e.g.
     * {@code org.example.YarnRpcTest.CalculatePBServiceImpl}.
     *
     * @return the class, or {@code null} if it cannot be found (the error goes to stderr)
     */
    public Class<?> getPbServiceImplClass() {
        String pbServiceImplName = protocol.getPackage().getName() + "."
                + protocol.getSimpleName() + "PBServiceImpl";
        return loadClass(pbServiceImplName);
    }

    /**
     * Loads the protoc-generated nested service class, e.g.
     * {@code calculateproto.Calculate$CalculateService}.
     *
     * @return the class, or {@code null} if it cannot be found (the error goes to stderr)
     */
    public Class<?> getProtoClass() {
        String simpleName = protocol.getSimpleName();
        return loadClass(protoPackage + "." + simpleName + "$" + simpleName + "Service");
    }

    /** Loads and initializes {@code name}; returns null and logs when it is missing. */
    private Class<?> loadClass(String name) {
        try {
            return Class.forName(name, true, classLoader);
        } catch (ClassNotFoundException e) {
            System.err.println(e.toString());
            return null;
        }
    }

    /**
     * Wraps this instance in the PB adapter, turns it into a {@link BlockingService}
     * via the generated {@code newReflectiveBlockingService}, and starts the server.
     *
     * @throws IllegalStateException if a required class or reflective step is missing or
     *                               fails; the original cause is attached when available
     */
    public void createServer() {
        Class<?> pbServiceImpl = getPbServiceImplClass();
        Class<?> protoClazz = getProtoClass();
        if (pbServiceImpl == null || protoClazz == null) {
            // Fail fast instead of continuing with nulls into an unrelated NPE.
            throw new IllegalStateException(
                    "cannot create server: generated PB classes for "
                            + protocol.getName() + " were not found");
        }
        try {
            Constructor<?> constructor = pbServiceImpl.getConstructor(protocol);
            constructor.setAccessible(true);
            Object service = constructor.newInstance(this);

            // The adapter's first interface is CalculatePB; its first super-interface is the
            // generated BlockingInterface that newReflectiveBlockingService accepts.
            Class<?> pbProtocol = service.getClass().getInterfaces()[0];
            Method factory = protoClazz.getMethod(
                    "newReflectiveBlockingService", pbProtocol.getInterfaces()[0]);
            factory.setAccessible(true);
            // Static factory method, hence the null receiver.
            BlockingService blockingService = (BlockingService) factory.invoke(null, service);

            createServer(pbProtocol, blockingService);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "cannot create server for " + protocol.getName(), e);
        }
    }

    /** Starts a {@link Server} thread serving {@code service} on {@link #port}. */
    public void createServer(Class<?> pbProtocol, BlockingService service) {
        server = new Server(pbProtocol, service, port);
        server.start();
    }

    public void init() {
        createServer();
    }

    public static void main(String[] args) {
        CalculateService server = new CalculateService();
        server.init();
    }
}
(6) 定义CalculatePBServiceImpl类,它是服务在PB(Protocol Buffers)格式下的最终实现:从RequestProto中取出参数,交给Calculate实现计算,再把结果封装为ResponseProto返回
package org.example.YarnRpcTest;
import calculateproto.CalculateMessage.ResponseProto;
import calculateproto.CalculateMessage.RequestProto;
import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
public class CalculatePBServiceImpl implements CalculatePB {
public Calculate real;
public CalculatePBServiceImpl(Calculate impl){
this.real = impl;
}
public ResponseProto add(RpcController controller, RequestProto request) throws ServiceException {
ResponseProto proto = ResponseProto.getDefaultInstance();
ResponseProto.Builder build = ResponseProto.newBuilder();
int add1 = request.getNum1();
int add2 = request.getNum2();
int sum = real.add(add1,add2);
ResponseProto result = null;
build.setResult(sum);
result = build.build();
System.out.println("PBServiceImpl:add,result = " + result);
return result;
}
public ResponseProto minus(RpcController controller,RequestProto request) throws ServiceException{
ResponseProto proto = ResponseProto.getDefaultInstance();
ResponseProto.Builder builder = ResponseProto.newBuilder();
int a = request.getNum1();
int b = request.getNum2();
int sum = real.minus(a,b);
ResponseProto result = null;
builder.setResult(sum);
result = builder.build();
System.out.println("PBServiceImpl:minus,result = " + result);
return result;
}
}
(7) 客户端程序Client
package org.example.YarnRpcTest;
import calculateproto.CalculateMessage;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.net.Socket;
import java.util.Random;
/**
 * RPC client stub for {@link Calculate}: every call opens a connection, sends one
 * length-prefixed {@code RequestProto} and reads one length-prefixed
 * {@code ResponseProto}.
 */
public class Client implements Calculate {
    private static final String DEFAULT_HOST = "localhost";
    private static final int DEFAULT_PORT = 8888;

    private final String host;
    private final int port;

    /** Connects to {@code localhost:8888}, matching the example server. */
    public Client() {
        this(DEFAULT_HOST, DEFAULT_PORT);
    }

    /**
     * @param host server host name or address
     * @param port server TCP port
     */
    public Client(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Performs one remote call.
     *
     * @param op RPC method name, e.g. "add" or "minus"
     * @param a  first operand
     * @param b  second operand
     * @return the server's result, or 0 if any error occurs (the error is printed)
     */
    public int domission(String op, int a, int b) {
        int ret = 0;
        // try-with-resources closes only what was actually opened, so a failed
        // connection no longer triggers an NPE during cleanup.
        try (Socket s = new Socket(host, port);
             DataOutputStream out = new DataOutputStream(s.getOutputStream());
             DataInputStream in = new DataInputStream(s.getInputStream())) {
            CalculateMessage.RequestProto.Builder builder =
                    CalculateMessage.RequestProto.newBuilder();
            builder.setMethodName(op);
            System.out.println("domission:op = " + op + ", a = " + a + ", b = " + b);
            builder.setNum1(a);
            builder.setNum2(b);
            byte[] bytes = builder.build().toByteArray();
            out.writeInt(bytes.length);
            out.write(bytes);
            out.flush();

            int dataLen = in.readInt();
            byte[] data = new byte[dataLen];
            // readFully waits for the whole frame; a single read() may return fewer bytes.
            in.readFully(data);
            CalculateMessage.ResponseProto result = CalculateMessage.ResponseProto.parseFrom(data);
            ret = result.getResult();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return ret;
    }

    @Override
    public int add(int num1, int num2) {
        return domission("add", num1, num2);
    }

    @Override
    public int minus(int num1, int num2) {
        return domission("minus", num1, num2);
    }

    public static void main(String[] args) {
        Client client = new Client();
        int testCount = 5;
        Random rand = new Random();
        while (testCount-- > 0) {
            int a = rand.nextInt(100);
            int b = rand.nextInt(100);
            int addresult = client.add(a, b);
            System.out.println("a:" + a + ",b:" + b + ",addresult=" + addresult);
            int minusresult = client.minus(a, b);
            System.out.println("a:" + a + ",b:" + b + ",minusresult=" + minusresult);
        }
    }
}