手写实现简单Redis命令客户端功能

RESP协议

Redis 的客户端和服务端之间采取了一种名为 Redis序列化的协议(REdis Serialization Protocol,简称RESP),是基于 TCP 的应用层协议 ,RESP 底层采用的是 TCP 的连接方式,通过 TCP 进行数据传输,然后根据解析规则解析相应信息。

在RESP协议中,数据的类型取决于第一个字节:

  • +开始表示单行字符串
  • -开始表示错误类型
  • :开始表示整数
  • $开始表示多行字符串
  • *开始表示数组

在RESP协议中,构成协议的每一部分必须使用\r\n作为结束符

命令格式示例:

SET key value

*3\r\n    #3表示这个命令由3部分组成
$3\r\n    # 第一部分的长度是3
SET\r\n   # 第一部分的内容
$3\r\n    # 第二部分的长度是3
key\r\n   # 第二部分的内容
$5\r\n    # 第三部分的长度是5
value\r\n # 第三部分的内容

几种响应格式示例:

简单字符串:

+OK\r\n


错误(Errors):

-ERR Error message\r\n


整数(Integers):

:1000\r\n


批量字符串(Bulk Strings):

$6\r\n
foobar\r\n

数组(Arrays)(示例包含两个元素,分别为 "foo" 和 "bar"):

*2\r\n
$3\r\n
foo\r\n
$3\r\n
bar\r\n
手写通过RESP协议实现客户端
package com.qf.redis.client;


import java.io.*;
import java.net.Socket;

public class RedisClient {
    public static void main(String[] args) {
        String ip = "localhost";
        int port = 6379;
        try {
            Socket socket = new Socket(ip,port);
            OutputStream os = socket.getOutputStream();
            //发送一个redis命令: set class java
            String[] commends = new String[]{"set","class","java"};
            String s = String.format("*%d\r\n", commends.length);
            os.write(s.getBytes());//os.write("*3\r\n".getBytes());
            for (int i = 0; i < commends.length; i++) {
                String s1 = String.format("$%d\r\n", commends[i].length());
                os.write(s1.getBytes());
                String s2 = String.format("%s\r\n", commends[i]);
                os.write(s2.getBytes());
            }
//            os.write("$3\r\n".getBytes());
//            os.write("set\r\n".getBytes());
//            os.write("$5\r\n".getBytes());
//            os.write("class\r\n".getBytes());
//            os.write("$4\r\n".getBytes());
//            os.write("java\\r\n".getBytes());
            os.flush();
            socket.shutdownOutput();//告诉服务器发送已完毕
            InputStream in = socket.getInputStream();
            BufferedReader reader = new BufferedReader(new InputStreamReader(in));
            String response;
            while ((response=reader.readLine())!=null){
                System.out.println(response);
            }
            reader.close();
            socket.close();

        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }
}

注:Redis采用的是RSEP协议进行通信,这种协议只支持传输字符串或者字节数组,传输字符串的时候,容器出现中文乱码问题,具体来说,是因为Redis默认的编码格式和java不同,所以在获取长度的时候,上面那种获取的是中文的长度,而不同编码格式对于中文汉字的长度可能不一样,从而导致Redis拿到长度后发现长度和对象不匹配,导致乱码。因此,推荐使用字节数组进行传输,这样计算的长度都是字节数组的长度,就不会发生乱码问题。

package com.qf.redis.client;


import java.io.*;
import java.net.Socket;

public class RedisClient {
    public static void main(String[] args) {
        String ip = "localhost";
        int port = 6379;
        try {
            Socket socket = new Socket(ip,port);
            OutputStream os = socket.getOutputStream();
            //发送一个redis命令: set class java
//            String[] commends = new String[]{"set","stu","zhangsan"};
            byte[][] commends = new byte[][]{
                    "set".getBytes(),
                    "stu".getBytes(),
                    "张三".getBytes()
            };
            String s = String.format("*%d\r\n", commends.length);
            os.write(s.getBytes());//os.write("*3\r\n".getBytes());
            for (byte[] commend : commends) {
                os.write("$".getBytes());
                os.write(Integer.toString(commend.length).getBytes());
                os.write("\r\n".getBytes());
                os.write(commend);
                os.write("\r\n".getBytes());
            }
//            for (int i = 0; i < commends.length; i++) {
//                String s1 = String.format("$%d\r\n", commends[i].length());
//                os.write(s1.getBytes());
//                String s2 = String.format("%s\r\n", commends[i]);
//                os.write(s2.getBytes());
//            }
//            os.write("$3\r\n".getBytes());
//            os.write("set\r\n".getBytes());
//            os.write("$5\r\n".getBytes());
//            os.write("class\r\n".getBytes());
//            os.write("$4\r\n".getBytes());
//            os.write("java\\r\n".getBytes());
            os.flush();
            socket.shutdownOutput();//告诉服务器发送已完毕
            InputStream in = socket.getInputStream();
            BufferedReader reader = new BufferedReader(new InputStreamReader(in));
            String response;
            while ((response=reader.readLine())!=null){
                System.out.println(response);
            }
            reader.close();
            socket.close();

        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

}

对于获得的结果,目前的处理是将结果直接进行打印,这样的处理并不合理,我们应该将获取结果封装为一个方法,根据不同的情况得到不同的返回值

package com.qf.redis.client;


import java.io.*;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;

public class RedisClient {
    public static void main(String[] args) {
        String ip = "localhost";
        int port = 6379;
        try {
            Socket socket = new Socket(ip,port);
            
            //发送一个redis命令: set class java
//            String[] commends = new String[]{"set","stu","zhangsan"};
            byte[][] commends = new byte[][]{
                    "get".getBytes(),
                    "stu".getBytes()
//                    "张三".getBytes()
            };
            sendCommend(socket, os, commends);

            Object result = getResult(socket);
            System.out.println(result);

        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

    private static Object getResult(Socket socket) throws IOException {
        InputStream in = socket.getInputStream();
        BufferedReader reader = new BufferedReader(new InputStreamReader(in));
        char read = (char) in.read();
        switch (read){
            case '+'://单行字符串
                return reader.readLine();
            case '*'://多行字符串,以数组的形式出现
                String str = reader.readLine();
                int rows = Integer.parseInt(str);
                List<String> results = new ArrayList<>();
                while (rows-- > 0){
                    reader.readLine();
                    results.add(reader.readLine());
                }
                return results;
            case ':'://单行字符串,表示整数
                str = reader.readLine();
                return Integer.parseInt(str);
            case '$'://多行字符串
                reader.readLine();
                return reader.readLine();
            case '-'://错误
                StringBuilder sb = new StringBuilder();
                String line;
                while((line = reader.readLine())!=null){
                    sb.append(line);
                }
                throw new RuntimeException(sb.toString());
            default:
                return "";
        }
//        String response;
//        while ((response=reader.readLine())!=null){
//            System.out.println(response);
//        }
//        reader.close();
//        socket.close();
    }

    private static void sendCommend(Socket socket, byte[][] commends) throws IOException {
        String s = String.format("*%d\r\n", commends.length);
        OutputStream os = socket.getOutputStream();
        os.write(s.getBytes());//os.write("*3\r\n".getBytes());
        for (byte[] commend : commends) {
            os.write("$".getBytes());
            os.write(Integer.toString(commend.length).getBytes());
            os.write("\r\n".getBytes());
            os.write(commend);
            os.write("\r\n".getBytes());
        }
        os.flush();
        socket.shutdownOutput();//告诉服务器发送已完毕
    }

}

以上是通过指定的命令进行编写,我们可以将get、set命令单独编写方法实现,但是在方法中如果想要实现设置和获取的对象不是简单的字符串,而是一个对象(比如集合),该怎么实现呢?

Redis本身主要就是用来做高速缓存的,因此其中可能会缓存一些查询数据,但这些数据在Java应用中都是以集合形式出现,因此需要提供一种方式能够将集合转换成字节数组,这样就能实现将集合数据存储在Redis中。

这个过程称为序列化,而将数据转换回集合并读取的过程叫做反序列化,序列化接口:

RedisSerializer

public interface RedisSerializer<T> {

    //这个接口方法就是将给定的对象转换成字节数组

    byte[] serialize(T t) throws IOException;

    //这个接口方法就是将给定的byte数组还原成T对象
    T deserialize(byte[] data) throws IOException, ClassNotFoundException;

}

实现类,包含两种,一种是输入格式为字符串,一种是字节数组:

字符串类型很好转换,因为String类型本身就提供了和字节数组之间的转换

public class StringRedisSerializer implements RedisSerializer<String>{
    @Override
    public byte[] serialize(String s) throws IOException {
        return s.getBytes();
    }

    @Override
    public String deserialize(byte[] data) throws IOException, ClassNotFoundException {
        return new String(data);
    }
}

字节数组相对来说比较麻烦,需要实现字节数据和集合对象之间的转换,有两种思路,第一种是通过反射获取对象的属性,然后转换为json格式(即map)的数据,再转换为字节数组,反序列化则是将字节数组转换为字符串,然后反射获取要转换对象的属性信息,就需要考虑不同属性类型进而进行不同的转化方式(对象属性要求装箱),十分复杂;

第二种则是通过对象流的形式直接进行读写,因此推荐这种方式,这里也只以这种方式进行实现

public class GenericObjectSerializer<T> implements RedisSerializer<T>{
    
    private Class<T> clazz;

    public GenericObjectSerializer(Class<T> clazz) {
        this.clazz = clazz;
    }

    @Override
    public byte[] serialize(T t) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(t);
        oos.close();
        return baos.toByteArray();

    }

    @Override
    public T deserialize(byte[] data) throws IOException, ClassNotFoundException {
        ByteArrayInputStream bais = new ByteArrayInputStream(data);
        ObjectInputStream ois = new ObjectInputStream(bais);
        T t = (T) ois.readObject();
        ois.close();
        return t;
    }
}

实现了字符串和字节数组的序列化和反序列化后,可以进行对get和set方法的编写,注意:这里的get和set已经不仅仅可以实现Redis中的原生的设置字符串和获取字符串,而是可以设置对象和获取对象。

package com.qf.redis.client;


import java.io.*;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RedisClient {
    public static void main(String[] args) {
//        String ip = "localhost";
//        int port = 6379;
//        try {
//            Socket socket = new Socket(ip,port);
//            OutputStream os = socket.getOutputStream();
//            //发送一个redis命令: set class java
            String[] commends = new String[]{"set","stu","zhangsan"};
//            byte[][] commends = new byte[][]{
//                    "get".getBytes(),
//                    "stu".getBytes()
                    "张三".getBytes()
//            };
//            sendCommend(socket, os, commends);
//
//            Object result = getResult(socket);
//            System.out.println(result);
//
//        } catch (IOException e) {
//            throw new RuntimeException(e);
//        }

//        set("hobbies", Arrays.asList("吃饭","睡觉","撸狗","撸猫","撸水豚"));
        get("hobbies");
    }

    private static Object getResult(Socket socket) throws IOException {
        InputStream in = socket.getInputStream();
        BufferedReader reader = new BufferedReader(new InputStreamReader(in));
        char read = (char) in.read();
        switch (read){
            case '+'://单行字符串
                return reader.readLine();
            case '*'://多行字符串,以数组的形式出现
                String str = reader.readLine();
                int rows = Integer.parseInt(str);
                List<String> results = new ArrayList<>();
                while (rows-- > 0){
                    reader.readLine();
                    results.add(reader.readLine());
                }
                return results;
            case ':'://单行字符串,表示整数
                str = reader.readLine();
                return Integer.parseInt(str);
            case '$'://多行字符串
                reader.readLine();
                return reader.readLine();
            case '-'://错误
                StringBuilder sb = new StringBuilder();
                String line;
                while((line = reader.readLine())!=null){
                    sb.append(line);
                }
                throw new RuntimeException(sb.toString());
            default:
                return "";
        }
//        String response;
//        while ((response=reader.readLine())!=null){
//            System.out.println(response);
//        }
//        reader.close();
//        socket.close();
    }

    private static void sendCommend(Socket socket, byte[][] commends) throws IOException {
        String s = String.format("*%d\r\n", commends.length);
        OutputStream os = socket.getOutputStream();
        os.write(s.getBytes());//os.write("*3\r\n".getBytes());
        for (byte[] commend : commends) {
            os.write("$".getBytes());
            os.write(Integer.toString(commend.length).getBytes());
            os.write("\r\n".getBytes());
            os.write(commend);
            os.write("\r\n".getBytes());
        }
        os.flush();
        socket.shutdownOutput();//告诉服务器发送已完毕
    }



    public static void get(String key){
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        RedisSerializer<List> valueSerializer = new GenericObjectSerializer<>(List.class);

        try {
            byte[] keyData = keySerializer.serialize(key);
            byte[][] commands = {"get".getBytes(),keyData};//{"get".getBytes(),"name".getBytes()}
            Socket socket = new Socket("localhost",6379);
            sendCommend(socket,commands);
            byte[] result = receiveMsg(socket);
            System.out.println(Arrays.toString(result));//没进行反序列化的结果,是字节数组强行转换为字符串的样子
            List list = valueSerializer.deserialize(result);
            System.out.println(list);//进行了反序列化后的结果
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void set(String key,Object value){
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        GenericObjectSerializer<Object> valueSerializer = new GenericObjectSerializer<>(Object.class);

        try {
            byte[] keyData = keySerializer.serialize(key);
            byte[] valueData = valueSerializer.serialize(value);
            byte[][] commands = {"set".getBytes(),keyData,valueData};//{"set".getBytes(),"name".getBytes(),"张三".getBytes()}
            Socket socket = new Socket("localhost",6379);
            sendCommend(socket,commands);
            Object result = getResult(socket);
            System.out.println(result);

        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

    /**
     * 这个方法是将序列化结果的字节数组读取出来
     * @param socket
     * @return
     */
    public static byte[] receiveMsg(Socket socket){
        try {
            InputStream in = socket.getInputStream();
            while (true){
                char c = (char) in.read();
                if(c == '\r'){
                    c = (char) in.read();
                    if(c == '\n')
                        break;
                }
            }
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            byte[] buffer = new byte[2048];
            int len;
            while ((len = in.read(buffer))!=-1){
                baos.write(buffer,0,len);
            }
            byte[] result = baos.toByteArray();
            baos.close();
            return result;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }
}

注:这里的get方法中不能用getResult()方法直接获取结果,因为这个方法只针对于没有进行序列化就存储的数据,而这里的数据需要反序列化,所以添加了receiveMsg()方法以将获取的数据视为字节数组并进行反序列化处理。

拓展:

但是假如我需要同时利用get和set方法存储序列化和非序列化的数据,该如何实现呢?

可以在set和get方法的参数中加一个布尔值参数,以区分是否需要进行序列化,如果需要则进行序列化

package com.qf.redis.client;


import java.io.*;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RedisClient {
    public static void main(String[] args) {

        set("hobbies", Arrays.asList("吃饭","睡觉","撸狗","撸猫","撸水豚"),true);
        get("hobbies",true);
//        get("stu",false);
//        set("tea","boduolaoshi",false);
//        get("tea",false);
    }


    private static Object getResult(Socket socket) throws IOException {
        InputStream in = socket.getInputStream();
        BufferedReader reader = new BufferedReader(new InputStreamReader(in));
        char read = (char) in.read();
        switch (read){
            case '+'://单行字符串
                return reader.readLine();
            case '*'://多行字符串,以数组的形式出现
                String str = reader.readLine();
                int rows = Integer.parseInt(str);
                List<String> results = new ArrayList<>();
                while (rows-- > 0){
                    reader.readLine();
                    results.add(reader.readLine());
                }
                return results;
            case ':'://单行字符串,表示整数
                str = reader.readLine();
                return Integer.parseInt(str);
            case '$'://多行字符串
                reader.readLine();
                return reader.readLine();
            case '-'://错误
                StringBuilder sb = new StringBuilder();
                String line;
                while((line = reader.readLine())!=null){
                    sb.append(line);
                }
                throw new RuntimeException(sb.toString());
            default:
                return "";
        }
//        String response;
//        while ((response=reader.readLine())!=null){
//            System.out.println(response);
//        }
//        reader.close();
//        socket.close();
    }

    private static void sendCommend(Socket socket, byte[][] commends) throws IOException {
        String s = String.format("*%d\r\n", commends.length);
        OutputStream os = socket.getOutputStream();
        os.write(s.getBytes());//os.write("*3\r\n".getBytes());
        for (byte[] commend : commends) {
            os.write("$".getBytes());
            os.write(Integer.toString(commend.length).getBytes());
            os.write("\r\n".getBytes());
            os.write(commend);
            os.write("\r\n".getBytes());
        }
        os.flush();
        socket.shutdownOutput();//告诉服务器发送已完毕
    }



    public static void get(String key,boolean isSerialize){
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        RedisSerializer<List> valueSerializer = new GenericObjectSerializer<>(List.class);

        try {
            byte[] keyData = keySerializer.serialize(key);
            byte[][] commands = {"get".getBytes(),keyData};//{"get".getBytes(),"name".getBytes()}
            Socket socket = new Socket("localhost",6379);
            sendCommend(socket,commands);
            if (isSerialize){
                byte[] result = receiveMsg(socket);
                List list = valueSerializer.deserialize(result);
                System.out.println(list);//进行了反序列化后的结果


            }else {
                Object result = getResult(socket);
                System.out.println(result);
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void set(String key,Object value,boolean isSerialize){
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        GenericObjectSerializer<Object> valueSerializer = new GenericObjectSerializer<>(Object.class);

        try {
            byte[] keyData = keySerializer.serialize(key);
            byte[] valueData;
            if(isSerialize){
                valueData = valueSerializer.serialize(value);
            }else {
                valueData = String.valueOf(value).getBytes();
            }

            byte[][] commands  = new byte[][]{"set".getBytes(), keyData, valueData};
            Socket socket = new Socket("localhost",6379);
            sendCommend(socket,commands);
            Object result = getResult(socket);
            System.out.println(result);

        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

    /**
     * 这个方法是将序列化结果的字节数组读取出来
     * @param socket
     * @return
     */
    public static byte[] receiveMsg(Socket socket){
        try {
            InputStream in = socket.getInputStream();
            while (true){
                char c = (char) in.read();
                if(c == '\r'){
                    c = (char) in.read();
                    if(c == '\n')
                        break;
                }
            }
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            byte[] buffer = new byte[2048];
            int len;
            while ((len = in.read(buffer))!=-1){
                baos.write(buffer,0,len);
            }
            byte[] result = baos.toByteArray();
            baos.close();
            return result;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }
}

  • 18
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值