模仿redis的内存数据库

package redis.exceptions;

/**
 * Base checked exception of this project; {@code redis.exceptions.RemoteException} extends it.
 *
 * <p>Every constructor simply forwards to the matching {@link Exception} constructor.
 */
public class MyException extends Exception {

    /**
     * Full-control constructor.
     *
     * @param message            detail message (may be null)
     * @param cause              underlying cause (may be null)
     * @param enableSuppression  whether suppressed exceptions are recorded
     * @param writableStackTrace whether the stack trace is captured
     */
    public MyException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }

    /**
     * @param message detail message (may be null)
     * @param cause   underlying cause (may be null)
     */
    public MyException(String message, Throwable cause) {
        super(message, cause);
    }

    /** @param cause underlying cause; its {@code toString()} becomes the message */
    public MyException(Throwable cause) {
        super(cause);
    }

    /** @param message detail message (may be null) */
    public MyException(String message) {
        super(message);
    }

    /** Creates an exception with neither message nor cause. */
    public MyException() {
        super();
    }
}
package redis.exceptions;

/**
 * Project exception for failures reported by the remote side.
 *
 * <p>Every constructor forwards unchanged to the matching {@link MyException} constructor.
 */
public class RemoteException extends MyException {

    /**
     * Full-control constructor.
     *
     * @param message            detail message (may be null)
     * @param cause              underlying cause (may be null)
     * @param enableSuppression  whether suppressed exceptions are recorded
     * @param writableStackTrace whether the stack trace is captured
     */
    public RemoteException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }

    /**
     * @param message detail message (may be null)
     * @param cause   underlying cause (may be null)
     */
    public RemoteException(String message, Throwable cause) {
        super(message, cause);
    }

    /** @param cause underlying cause */
    public RemoteException(Throwable cause) {
        super(cause);
    }

    /** @param message detail message (may be null) */
    public RemoteException(String message) {
        super(message);
    }

    /** Creates an exception with neither message nor cause. */
    public RemoteException() {
        super();
    }
}

 

package redis;

import redis.exceptions.MyException;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

/**
 * Encoder/decoder for the RESP wire format exchanged between client and server.
 *
 * <p>{@link #read} decodes one value from the stream:
 * <ul>
 *   <li>{@code +...} simple string: {@link String}</li>
 *   <li>{@code -...} error: thrown as {@link RemoteException}</li>
 *   <li>{@code :...} integer: {@link Long}</li>
 *   <li>{@code $...} bulk string: {@code byte[]}, or {@code null} for {@code $-1}</li>
 *   <li>{@code *...} array: {@link List}, or {@code null} for {@code *-1}; an error element inside
 *       an array is stored in the list as the exception object</li>
 * </ul>
 *
 * <p>A premature end of stream or malformed framing is reported with a {@link RuntimeException}.
 *
 * <p>NOTE(review): {@code RemoteException} here is {@code java.rmi.RemoteException} (an
 * {@link IOException}), not {@code redis.exceptions.RemoteException}. Switching to the project type
 * would change the checked-exception signature of {@link #read}.
 */
public class Protocol {
    /** Terminator of every RESP line. */
    private static final byte[] CRLF = {'\r', '\n'};

    /** Largest bulk string accepted from a peer (512 MB, the same as Redis' default limit). */
    private static final int MAX_BULK_LENGTH = 512 * 1024 * 1024;

    /** Cap on the pre-allocated capacity of a decoded array; the peer-supplied length is untrusted. */
    private static final int MAX_INITIAL_ARRAY_CAPACITY = 1024;

    /** Command names are turned into class names, so only letters and digits are accepted. */
    private static final Pattern COMMAND_NAME = Pattern.compile("[A-Za-z0-9]+");

    private static String processSimpleString(InputStream is) throws IOException {
        // "+OK\r\n" -> "OK"
        return readLine(is);
    }

    private static String processError(InputStream is) throws IOException {
        // "-ERR message\r\n" -> "ERR message"
        return readLine(is);
    }

    private static long processInteger(InputStream is) throws IOException {
        // ":1000\r\n" -> 1000
        return readInteger(is);
    }

    private static byte[] processBulkString(InputStream is) throws IOException {
        // "$5\r\nhello\r\n" -> "hello"; "$-1\r\n" -> null
        long len = readInteger(is);
        if (len == -1) {
            return null;
        }
        if (len < 0 || len > MAX_BULK_LENGTH) {
            throw new RuntimeException("错误的长度: " + len);
        }

        byte[] r = new byte[(int) len];
        // A single read(byte[], int, int) may return fewer bytes than requested (e.g. when the
        // payload spans several TCP segments), so keep reading until the buffer is full.
        readFully(is, r);
        expectCrlf(is);
        return r;
    }

    private static List<Object> processArray(InputStream is) throws IOException {
        // "*2\r\n..." -> list of 2 decoded elements; "*-1\r\n" -> null
        long len = readInteger(is);
        if (len == -1) {
            return null;
        }
        if (len < 0 || len > Integer.MAX_VALUE) {
            throw new RuntimeException("错误的数组长度: " + len);
        }

        int count = (int) len;
        List<Object> list = new ArrayList<>(Math.min(count, MAX_INITIAL_ARRAY_CAPACITY));
        for (int i = 0; i < count; i++) {
            try {
                list.add(process(is));
            } catch (RemoteException e) {
                // An error element inside an array is a value, not a failure of the whole reply.
                list.add(e);
            }
        }
        return list;
    }

    /**
     * Reads exactly one RESP value.
     *
     * @param is stream positioned at the type byte of a value
     * @return the decoded value (see the class documentation for the type mapping)
     * @throws RemoteException if the value is a RESP error
     * @throws IOException     on I/O failure
     */
    public static Object read(InputStream is) throws IOException {
        return process(is);
    }

    private static Object process(InputStream is) throws IOException {
        int b = readByte(is);
        switch (b) {
            case '+':
                return processSimpleString(is);
            case '-':
                throw new RemoteException(processError(is));
            case ':':
                return processInteger(is);
            case '$':
                return processBulkString(is);
            case '*':
                return processArray(is);
            default:
                throw new RuntimeException("不识别的类型");
        }
    }

    /**
     * Reads bytes up to and excluding the next {@code "\r\n"}. A lone {@code '\r'} that is not
     * followed by {@code '\n'} is kept as ordinary content. Each byte becomes one {@code char}
     * (ISO-8859-1 style), so multi-byte text is not decoded.
     *
     * @param is source stream
     * @return the line without its terminator
     * @throws IOException on I/O failure
     */
    public static String readLine(InputStream is) throws IOException {
        StringBuilder sb = new StringBuilder();
        int b = readByte(is);
        while (true) {
            if (b != '\r') {
                sb.append((char) b);
                b = readByte(is);
                continue;
            }
            int c = readByte(is);
            if (c == '\n') {
                return sb.toString();
            }
            sb.append('\r');
            // c is processed next: it may itself be a '\r' that starts the real terminator.
            b = c;
        }
    }

    /**
     * Reads one RESP command (an array whose first element is the command name) and instantiates
     * {@code redis.command.<NAME>Command}, passing the remaining elements as its arguments.
     *
     * @param is source stream
     * @return the ready-to-run command
     * @throws MyException if the request is not an array
     * @throws Exception   if the request is empty, the name is not a bulk string or is not a valid
     *                     command name, or no matching {@link Command} class exists
     */
    public static Command readCommand(InputStream is) throws Exception {
        Object o = read(is);
        // A server only ever receives commands as arrays, never e.g. "+OK\r\n".
        if (!(o instanceof List)) {
            throw new MyException("命令必须是Array类型");
        }
        @SuppressWarnings("unchecked")
        List<Object> list = (List<Object>) o;
        if (list.isEmpty()) {
            throw new Exception("命令元素至少要有1个");
        }
        Object head = list.remove(0);
        if (!(head instanceof byte[])) {
            throw new Exception("错误的命令类型");
        }

        String commandName = new String((byte[]) head);
        // The name comes from the network and selects a class to load: reject anything that could
        // reach another package (e.g. containing '.').
        if (!COMMAND_NAME.matcher(commandName).matches()) {
            throw new Exception("错误的命令");
        }
        // Locale.ROOT: under e.g. a Turkish default locale "i".toUpperCase() is not "I", which
        // would produce a wrong class name.
        String className = String.format("redis.command.%sCommand", commandName.toUpperCase(Locale.ROOT));
        // Load without initializing, so no static initializer runs before we know the class
        // really implements Command.
        Class<?> cls = Class.forName(className, false, Protocol.class.getClassLoader());
        if (!Command.class.isAssignableFrom(cls)) {
            throw new Exception("错误的命令");
        }
        Command command = (Command) cls.getDeclaredConstructor().newInstance();
        command.setArgs(list);
        return command;
    }

    /**
     * Writes a RESP error ({@code -message\r\n}). CR and LF inside the message are replaced by
     * spaces so they cannot break the framing; a {@code null} message is written as "null".
     *
     * @param os      destination stream
     * @param message error text
     * @throws IOException on I/O failure
     */
    public static void writeError(OutputStream os, String message) throws IOException {
        String line = String.valueOf(message).replace('\r', ' ').replace('\n', ' ');
        os.write('-');
        os.write(line.getBytes());
        os.write(CRLF);
    }

    /**
     * Reads a decimal integer terminated by {@code "\r\n"}; an optional leading sign is accepted.
     *
     * @param is source stream
     * @return the parsed value
     * @throws NumberFormatException if the line is not a valid {@code long}
     * @throws IOException           on I/O failure
     */
    public static long readInteger(InputStream is) throws IOException {
        StringBuilder sb = new StringBuilder();
        int b = is.read();
        if (b == -1) {
            throw new RuntimeException("不应该读到结尾");
        }
        while (b != '\r') {
            sb.append((char) b);
            b = readByte(is);
        }
        if (readByte(is) != '\n') {
            throw new RuntimeException("没有读到\\r\\n");
        }
        // Long.parseLong handles the sign itself, which (unlike negating a parsed magnitude)
        // also accepts Long.MIN_VALUE.
        return Long.parseLong(sb.toString());
    }

    /**
     * Writes a RESP integer ({@code :v\r\n}).
     *
     * @param os destination stream
     * @param v  value to write
     * @throws IOException on I/O failure
     */
    public static void writeInteger(OutputStream os, long v) throws IOException {
        os.write(':');
        os.write(Long.toString(v).getBytes());
        os.write(CRLF);
    }

    /**
     * Writes a RESP array. {@code String} and {@code byte[]} elements become bulk strings,
     * {@code Integer}/{@code Long} become integers, and {@code null} becomes a null bulk string.
     * A {@code null} list is written as the null array {@code *-1\r\n}.
     *
     * @param os   destination stream
     * @param list elements to write
     * @throws Exception if an element has an unsupported type; nothing is written in that case
     */
    public static void writeArray(OutputStream os, List<?> list) throws Exception {
        if (list == null) {
            os.write('*');
            os.write('-');
            os.write('1');
            os.write(CRLF);
            return;
        }
        // Validate first so an unsupported element cannot leave a half-written reply behind.
        for (Object o : list) {
            if (!isEncodable(o)) {
                throw new Exception("错误的类型");
            }
        }

        os.write('*');
        os.write(Integer.toString(list.size()).getBytes());
        os.write(CRLF);
        for (Object o : list) {
            if (o == null) {
                writeNull(os);
            } else if (o instanceof String) {
                writeBulkString(os, (String) o);
            } else if (o instanceof byte[]) {
                writeBulkBytes(os, (byte[]) o);
            } else {
                writeInteger(os, ((Number) o).longValue());
            }
        }
    }

    /** Returns whether {@link #writeArray} knows how to encode {@code o}. */
    private static boolean isEncodable(Object o) {
        return o == null || o instanceof String || o instanceof byte[]
                || o instanceof Integer || o instanceof Long;
    }

    /**
     * Writes a RESP bulk string ({@code $len\r\ndata\r\n}) using the platform default charset;
     * {@code null} is written as a null bulk string.
     *
     * @param os destination stream
     * @param s  text to write
     * @throws IOException on I/O failure
     */
    public static void writeBulkString(OutputStream os, String s) throws IOException {
        if (s == null) {
            writeNull(os);
            return;
        }
        writeBulkBytes(os, s.getBytes());
    }

    /** Writes {@code buf} as a RESP bulk string. */
    private static void writeBulkBytes(OutputStream os, byte[] buf) throws IOException {
        os.write('$');
        os.write(Integer.toString(buf.length).getBytes());
        os.write(CRLF);
        os.write(buf);
        os.write(CRLF);
    }

    /**
     * Writes the RESP null bulk string {@code $-1\r\n}.
     *
     * @param os destination stream
     * @throws IOException on I/O failure
     */
    public static void writeNull(OutputStream os) throws IOException {
        os.write('$');
        os.write('-');
        os.write('1');
        os.write(CRLF);
    }

    /** Reads one byte, failing on end of stream (a complete value was expected). */
    private static int readByte(InputStream is) throws IOException {
        int b = is.read();
        if (b == -1) {
            throw new RuntimeException("不应该读到结尾的");
        }
        return b;
    }

    /** Fills {@code buf} completely, looping over short reads; fails on end of stream. */
    private static void readFully(InputStream is, byte[] buf) throws IOException {
        int off = 0;
        while (off < buf.length) {
            int n = is.read(buf, off, buf.length - off);
            if (n == -1) {
                throw new RuntimeException("不应该读到结尾的");
            }
            off += n;
        }
    }

    /** Consumes the two bytes that must follow a bulk payload and checks they are "\r\n". */
    private static void expectCrlf(InputStream is) throws IOException {
        if (readByte(is) != '\r' || readByte(is) != '\n') {
            throw new RuntimeException("没有读到\\r\\n");
        }
    }
}

package redis;

import java.util.*;

/**
 * Process-wide in-memory store with one static map per value kind (lists and hashes).
 *
 * <p>The getters never return null: on first access to a key they create an empty collection,
 * register it, and return it, so callers mutate the stored value directly.
 *
 * <p>Not thread-safe: the maps and the collections they hold are plain HashMap/ArrayList.
 */
public class Database {
    /** key -> list value. */
    private static Map<String, List<String>> lists = new HashMap<>();
    /** key -> hash value (field -> value). */
    private static Map<String, Map<String, String>> hashes = new HashMap<>();

    /**
     * Returns the live list stored under {@code key}, creating an empty one if absent.
     *
     * @param key list key
     * @return the mutable list stored for {@code key}
     */
    public static List<String> getList(String key) {
        return lists.computeIfAbsent(key, ignored -> new ArrayList<>());
    }

    /**
     * Returns the live hash stored under {@code key}, creating an empty one if absent.
     *
     * @param key hash key
     * @return the mutable field-to-value map stored for {@code key}
     */
    public static Map<String, String> getHashes(String key) {
        return hashes.computeIfAbsent(key, ignored -> new HashMap<>());
    }
}
package redis;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;

/**
 * A server-side command. An instance is created per request, receives its arguments through
 * {@link #setArgs}, and is then executed with {@link #run}.
 */
public interface Command {

    /**
     * Executes the command and writes its reply to the client.
     *
     * @param os stream the reply is written to
     * @throws IOException if writing the reply fails
     */
    void run(OutputStream os) throws IOException;

    /**
     * Supplies the request elements that followed the command name (implementations in this
     * project treat each one as a {@code byte[]} bulk string).
     *
     * @param args the command's arguments, in request order
     */
    void setArgs(List<Object> args);
}
package redis.command;

import redis.*;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;

/**
 * {@code HGET key field}: replies with the value of {@code field} in the hash stored at
 * {@code key} as a bulk string, or with a null bulk string when the field is absent.
 */
public class HGETCommand implements Command {
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        // Without these checks a malformed request throws (IndexOutOfBounds / ClassCast) before
        // any reply is written; LPUSH already validates its arity the same way.
        if (args == null || args.size() != 2) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'hget' command");
            return;
        }
        if (!(args.get(0) instanceof byte[]) || !(args.get(1) instanceof byte[])) {
            Protocol.writeError(os, "ERR arguments must be bulk strings");
            return;
        }
        String key = new String((byte[]) args.get(0));
        String field = new String((byte[]) args.get(1));

        // NOTE(review): Database.getHashes creates and stores an empty hash for an unknown key,
        // so even this read-only command leaves an entry behind.
        Map<String, String> hash = Database.getHashes(key);
        String value = hash.get(field);
        if (value != null) {
            Protocol.writeBulkString(os, value);
        } else {
            Protocol.writeNull(os);
        }
    }
}

package redis.command;

import redis.*;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;

/**
 * {@code HSET key field value [field value ...]}: stores every field/value pair in the hash at
 * {@code key} and replies with the number of fields that did not exist before. Existing fields
 * are overwritten and not counted, so a single pair yields 1 (created) or 0 (updated).
 */
public class HSETCommand implements Command {
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        // Need the key plus at least one complete field/value pair.
        if (args == null || args.size() < 3 || args.size() % 2 == 0) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'hset' command");
            return;
        }
        // Validate everything before mutating, so a bad request changes nothing.
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }

        String key = new String((byte[]) args.get(0));
        Map<String, String> hash = Database.getHashes(key);
        long created = 0;
        for (int i = 1; i < args.size(); i += 2) {
            String field = new String((byte[]) args.get(i));
            String value = new String((byte[]) args.get(i + 1));
            if (!hash.containsKey(field)) {
                created++;
            }
            hash.put(field, value);
        }
        Protocol.writeInteger(os, created);
    }
}

package redis.command;
import redis.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.Command;
import redis.Database;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;

/**
 * {@code LPUSH key value [value ...]}: inserts each value at the head of the list stored at
 * {@code key}, left to right (so the last value ends up first), and replies with the new length.
 */
public class LPUSHCommand implements Command {
    private static final Logger logger = LoggerFactory.getLogger(LPUSHCommand.class);
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        if (args == null || args.size() < 2) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'lpush' command");
            return;
        }
        // Validate before mutating, so a bad request changes nothing.
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }
        String key = new String((byte[]) args.get(0));

        // Database hands out a plain, unsynchronized ArrayList: this is not safe if commands
        // ever run concurrently. Head insertion into an ArrayList is also O(n) per value.
        List<String> list = Database.getList(key);
        for (int i = 1; i < args.size(); i++) {
            String value = new String((byte[]) args.get(i));
            logger.debug("运行的是 lpush 命令: {} {}", key, value);
            list.add(0, value);
        }

        logger.debug("插入后数据共有 {} 个", list.size());

        Protocol.writeInteger(os, list.size());
    }
}
package redis.command;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.*;

import redis.Command;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;

/**
 * {@code LRANGE key start stop}: replies with the list elements between the inclusive indexes
 * {@code start} and {@code stop}.
 *
 * <p>Negative indexes count from the tail ({@code -1} is the last element). Out-of-range indexes
 * are clamped to the list bounds, and an empty range (including a missing or empty list) yields an
 * empty array.
 */
public class LRANGECommand implements Command {
    private static final Logger logger = LoggerFactory.getLogger(LRANGECommand.class);
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        if (args == null || args.size() != 3) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'lrange' command");
            return;
        }
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }
        String key = new String((byte[]) args.get(0));
        int start;
        int end;
        try {
            start = Integer.parseInt(new String((byte[]) args.get(1)));
            end = Integer.parseInt(new String((byte[]) args.get(2)));
        } catch (NumberFormatException e) {
            Protocol.writeError(os, "ERR value is not an integer or out of range");
            return;
        }

        List<String> list = Database.getList(key);
        int size = list.size();
        // Resolve negative indexes relative to the tail, then clamp to [0, size - 1];
        // otherwise subList would throw for these inputs.
        if (start < 0) {
            start = Math.max(size + start, 0);
        }
        if (end < 0) {
            end = size + end;
        }
        if (end >= size) {
            end = size - 1;
        }
        // start > end also covers start >= size and an empty list.
        List<String> result = start > end ? list.subList(0, 0) : list.subList(start, end + 1);

        try {
            Protocol.writeArray(os, result);
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            // writeArray rejects element types it cannot encode; report it to the client instead
            // of swallowing it and leaving the request without a reply.
            logger.error("failed to write lrange reply", e);
            Protocol.writeError(os, "ERR " + e.getMessage());
        }
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>redis-like</groupId>
    <artifactId>redis</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <!-- Standard encoding properties read by the compiler/resources plugins. The sources contain
             non-ASCII (Chinese) string literals, so the build must not depend on the platform encoding. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.36</version>
        </dependency>

        <!-- logback 1.2.3 is affected by known CVEs (e.g. CVE-2021-42550); 1.2.13 is a patched release
             of the same Java 8 / SLF4J 1.7 compatible 1.2.x line. -->
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.13</version>
        </dependency>
    </dependencies>
</project>

改进点

1.单线程(不支持并发连接) 引入多线程

2.连接的处理循环没有停止条件:需要识别出对方关闭连接的情况,并跳出循环

3.错误处理不够丰富

4.持久化(存到磁盘/MySQL)

 

 

 

 

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值