package redis.exceptions;
public class MyException extends Exception {

    /*
     * Checked exception used by this project (e.g. thrown when a request is not an array).
     * Every constructor forwards to the java.lang.Exception constructor with the same
     * parameters; none of them delegate to each other, so the "cause not yet initialized"
     * state of the message-only and no-arg forms is preserved (initCause can still be called).
     */

    /** Creates an exception with neither a detail message nor an initialized cause. */
    public MyException() {
        super();
    }

    /**
     * Creates an exception wrapping {@code cause}; per {@link Exception#Exception(Throwable)}
     * the detail message becomes {@code cause.toString()} (or {@code null} for a null cause).
     */
    public MyException(Throwable cause) {
        super(cause);
    }

    /** Creates an exception with the given detail message; the cause is left uninitialized. */
    public MyException(String message) {
        super(message);
    }

    /** Creates an exception with the given detail message and cause. */
    public MyException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates an exception with full control over suppression and stack-trace capture.
     *
     * @param message            detail message
     * @param cause              underlying cause, may be null
     * @param enableSuppression  whether {@link Throwable#addSuppressed} records anything
     * @param writableStackTrace whether the stack trace is captured/writable
     */
    public MyException(String message, Throwable cause,
                       boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }
}
package redis.exceptions;
public class RemoteException extends MyException {

    /*
     * Project-level exception for errors reported by the remote side. Not to be confused with
     * java.rmi.RemoteException, which is what redis.Protocol in this project imports and throws.
     * Each constructor forwards unchanged to the MyException constructor of the same shape.
     */

    /** Creates an exception with neither a detail message nor an initialized cause. */
    public RemoteException() {
        super();
    }

    /** Creates an exception wrapping {@code cause}; the message is derived from the cause. */
    public RemoteException(Throwable cause) {
        super(cause);
    }

    /** Creates an exception with the given detail message; the cause is left uninitialized. */
    public RemoteException(String message) {
        super(message);
    }

    /** Creates an exception with the given detail message and cause. */
    public RemoteException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates an exception with full control over suppression and stack-trace capture.
     *
     * @param message            detail message
     * @param cause              underlying cause, may be null
     * @param enableSuppression  whether suppressed exceptions are recorded
     * @param writableStackTrace whether the stack trace is captured/writable
     */
    public RemoteException(String message, Throwable cause,
                           boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }
}
package redis;
import redis.exceptions.MyException;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
// NOTE: this is java.rmi.RemoteException (an IOException subclass), not redis.exceptions.RemoteException.
import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
/**
 * Static helpers for reading and writing the Redis serialization protocol (RESP).
 *
 * <p>{@link #read(InputStream)} dispatches on the leading type byte:
 * <ul>
 *   <li>{@code +} simple string: returned as a {@link String}</li>
 *   <li>{@code :} integer: returned as a {@link Long}</li>
 *   <li>{@code $} bulk string: returned as {@code byte[]}, or {@code null} for {@code "$-1\r\n"}</li>
 *   <li>{@code *} array: returned as {@code List<Object>}, or {@code null} for {@code "*-1\r\n"}</li>
 *   <li>{@code -} error: thrown as {@code java.rmi.RemoteException} (the type imported by this
 *       file, an {@link IOException} subclass; not {@code redis.exceptions.RemoteException})</li>
 * </ul>
 * An unexpected end of stream is reported as a {@link RuntimeException}.
 */
public class Protocol {
    /** Line terminator used by every RESP frame. */
    private static final byte[] CRLF = {'\r', '\n'};

    /** Encoded null bulk string, {@code "$-1\r\n"}. */
    private static final byte[] NULL_BULK_STRING = {'$', '-', '1', '\r', '\n'};

    /**
     * Upper bound on the initial capacity reserved for an array. The element count comes from
     * the peer, so it is not trusted for up-front allocation; the list still grows as needed.
     */
    private static final int MAX_INITIAL_ARRAY_CAPACITY = 1024;

    /** Parses the payload of a simple string such as {@code "+OK\r\n"}; '+' is already consumed. */
    private static String processSimpleString(InputStream is) throws IOException {
        return readLine(is);
    }

    /** Parses the payload of an error reply; '-' is already consumed. */
    private static String processError(InputStream is) throws IOException {
        return readLine(is);
    }

    /** Parses the payload of an integer reply; ':' is already consumed. */
    private static long processInteger(InputStream is) throws IOException {
        return readInteger(is);
    }

    /**
     * Parses the payload of a bulk string such as {@code "$5\r\nhello\r\n"}; '$' is already consumed.
     *
     * @return the raw bytes, or {@code null} for the null bulk string {@code "$-1\r\n"}
     */
    private static byte[] processBulkString(InputStream is) throws IOException {
        int len = (int) readInteger(is);
        if (len == -1) {
            return null;
        }
        byte[] r = new byte[len];
        // One InputStream.read(byte[], int, int) call may return fewer than len bytes (e.g. when
        // the payload spans several TCP segments), so keep reading until the buffer is full.
        readFully(is, r);
        // Skip the CRLF that terminates the payload.
        is.read();
        is.read();
        return r;
    }

    /** Fills {@code buf} completely from {@code is}, failing if the stream ends first. */
    private static void readFully(InputStream is, byte[] buf) throws IOException {
        int off = 0;
        while (off < buf.length) {
            int n = is.read(buf, off, buf.length - off);
            if (n == -1) {
                throw new RuntimeException("不应该读到结尾的");
            }
            off += n;
        }
    }

    /**
     * Parses the payload of an array such as {@code "*2\r\n..."}; '*' is already consumed.
     *
     * <p>An error element does not abort the parse: the {@code RemoteException} it produces is
     * stored in the list at that element's position.
     *
     * @return the elements, or {@code null} for the null array {@code "*-1\r\n"}
     */
    private static List<Object> processArray(InputStream is) throws IOException {
        int len = (int) readInteger(is);
        if (len == -1) {
            return null;
        }
        List<Object> list = new ArrayList<>(Math.min(len, MAX_INITIAL_ARRAY_CAPACITY));
        for (int i = 0; i < len; i++) {
            try {
                list.add(process(is));
            } catch (RemoteException e) {
                list.add(e);
            }
        }
        return list;
    }

    /** Reads one complete RESP value; see the class documentation for the result types. */
    public static Object read(InputStream is) throws IOException {
        return process(is);
    }

    /** Reads the type byte and dispatches to the matching payload parser. */
    private static Object process(InputStream is) throws IOException {
        int b = is.read();
        if (b == -1) {
            throw new RuntimeException("不应该读到结尾的");
        }
        switch (b) {
            case '+':
                return processSimpleString(is);
            case '-':
                throw new RemoteException(processError(is));
            case ':':
                return processInteger(is);
            case '$':
                return processBulkString(is);
            case '*':
                return processArray(is);
            default:
                throw new RuntimeException("不识别的类型");
        }
    }

    /**
     * Reads bytes up to the next CRLF and returns them (without the CRLF) as a string, mapping
     * each byte to one char. A CR that is not followed by LF is kept as data.
     *
     * @throws RuntimeException if the stream ends before a CRLF is found
     */
    public static String readLine(InputStream is) throws IOException {
        boolean needRead = true;
        StringBuilder sb = new StringBuilder();
        int b = -1;
        while (true) {
            if (needRead == true) {
                b = is.read();
                if (b == -1) {
                    throw new RuntimeException("不应该读到结尾的");
                }
            } else {
                // b already holds a CR carried over from the previous iteration.
                needRead = true;
            }
            if (b == '\r') {
                int c = is.read();
                if (c == -1) {
                    throw new RuntimeException("不应该读到结尾的");
                }
                if (c == '\n') {
                    break;
                }
                if (c == '\r') {
                    // "\r\r": keep the first CR as data and re-examine the second one, which
                    // may be the start of the terminating CRLF.
                    sb.append((char) b);
                    b = c;
                    needRead = false;
                } else {
                    sb.append((char) b);
                    sb.append((char) c);
                }
            } else {
                sb.append((char) b);
            }
        }
        return sb.toString();
    }

    /**
     * Reads one client request and instantiates the matching command handler.
     *
     * <p>The request must be an array whose first element is a bulk string naming the command.
     * The handler class is looked up as {@code redis.command.<NAME>Command} (name upper-cased)
     * and receives the remaining elements through {@link Command#setArgs(List)}.
     *
     * @throws MyException            if the request is not an array
     * @throws ClassNotFoundException if no handler class exists for the command name
     * @throws Exception              for an empty request, a non-bulk-string or malformed
     *                                command name, or a class that is not a {@link Command}
     */
    public static Command readCommand(InputStream is) throws Exception {
        Object o = read(is);
        // A server never expects a reply-type value such as "+OK\r\n" here.
        if (!(o instanceof List)) {
            throw new MyException("命令必须是Array类型");
        }
        @SuppressWarnings("unchecked") // processArray always builds a List<Object>
        List<Object> list = (List<Object>) o;
        if (list.isEmpty()) {
            throw new Exception("命令元素必须大于1");
        }
        Object first = list.remove(0);
        if (!(first instanceof byte[])) {
            throw new Exception("错误的命令类型");
        }
        String commandName = new String((byte[]) first, StandardCharsets.UTF_8);
        // The name comes straight from the network and is spliced into a class name: allow only
        // letters/digits so it cannot reach other packages ('.') or nested classes ('$').
        if (commandName.isEmpty() || !commandName.chars().allMatch(Character::isLetterOrDigit)) {
            throw new Exception("错误的命令");
        }
        // Locale.ROOT: locale-sensitive upper-casing would turn e.g. "lindex" into "LİNDEX"
        // under a Turkish default locale.
        String className = String.format("redis.command.%sCommand",
                commandName.toUpperCase(Locale.ROOT));
        // initialize=false: do not run a class's static initializers before it is known to be a Command.
        Class<?> cls = Class.forName(className, false, Protocol.class.getClassLoader());
        if (!Command.class.isAssignableFrom(cls)) {
            throw new Exception("错误的命令");
        }
        // Class.newInstance() is deprecated: it can propagate undeclared checked exceptions.
        Command command = (Command) cls.getDeclaredConstructor().newInstance();
        command.setArgs(list);
        return command;
    }

    /**
     * Writes an error reply, {@code "-<message>\r\n"}.
     *
     * <p>RESP errors are single-line, so any CR/LF inside {@code message} is replaced with a
     * space to keep the frame intact.
     */
    public static void writeError(OutputStream os, String message) throws IOException {
        String line = message.replace('\r', ' ').replace('\n', ' ');
        os.write('-');
        os.write(line.getBytes(StandardCharsets.UTF_8));
        os.write(CRLF);
    }

    /**
     * Reads a signed decimal integer terminated by CRLF (the header line of ':', '$' and '*').
     *
     * @throws RuntimeException      if the stream ends early or a CR is not followed by LF
     * @throws NumberFormatException if the line is not a valid {@code long}
     */
    public static long readInteger(InputStream is) throws IOException {
        int b = is.read();
        if (b == -1) {
            throw new RuntimeException("不应该读到结尾");
        }
        StringBuilder sb = new StringBuilder();
        while (b != '\r') {
            sb.append((char) b);
            b = is.read();
            if (b == -1) {
                throw new RuntimeException("不应该读到结尾的");
            }
        }
        int c = is.read();
        if (c == -1) {
            throw new RuntimeException("不应该读到结尾的");
        }
        if (c != '\n') {
            throw new RuntimeException("没有读到\\r\\n");
        }
        // Let parseLong handle the sign: parsing the digits and negating afterwards overflows
        // for Long.MIN_VALUE ("-9223372036854775808").
        return Long.parseLong(sb.toString());
    }

    /** Writes an integer reply, {@code ":<v>\r\n"}. */
    public static void writeInteger(OutputStream os, long v) throws IOException {
        os.write(':');
        os.write(Long.toString(v).getBytes(StandardCharsets.US_ASCII));
        os.write(CRLF);
    }

    /**
     * Writes {@code list} as a RESP array.
     *
     * <p>Elements are encoded by type:
     * <ul>
     *   <li>{@link String}: bulk string</li>
     *   <li>{@link Integer} or {@link Long}: integer</li>
     *   <li>{@code null}: null bulk string</li>
     *   <li>a nested {@link List}: nested array</li>
     * </ul>
     *
     * @throws Exception for any other element type (output written so far is not rolled back)
     */
    public static void writeArray(OutputStream os, List<?> list) throws Exception {
        os.write('*');
        os.write(Integer.toString(list.size()).getBytes(StandardCharsets.US_ASCII));
        os.write(CRLF);
        for (Object o : list) {
            if (o == null) {
                writeNull(os);
            } else if (o instanceof String) {
                writeBulkString(os, (String) o);
            } else if (o instanceof Integer || o instanceof Long) {
                writeInteger(os, ((Number) o).longValue());
            } else if (o instanceof List) {
                writeArray(os, (List<?>) o);
            } else {
                throw new Exception("错误的类型");
            }
        }
    }

    /**
     * Writes {@code s} as a bulk string, {@code "$<len>\r\n<bytes>\r\n"}.
     *
     * <p>Uses the platform default charset, matching the {@code new String(byte[])} decoding the
     * command handlers use for their arguments, so stored values round-trip unchanged.
     */
    public static void writeBulkString(OutputStream os, String s) throws IOException {
        byte[] buf = s.getBytes();
        os.write('$');
        os.write(Integer.toString(buf.length).getBytes(StandardCharsets.US_ASCII));
        os.write(CRLF);
        os.write(buf);
        os.write(CRLF);
    }

    /** Writes the null bulk string, {@code "$-1\r\n"}. */
    public static void writeNull(OutputStream os) throws IOException {
        os.write(NULL_BULK_STRING);
    }
}
package redis;
import java.util.*;
public class Database {
    /*
     * In-memory key space: one map for list values, one for hash values. All state is static
     * and held in plain HashMap/ArrayList instances with no synchronization, so access must
     * stay on a single thread (or be guarded externally) if connections are ever served
     * concurrently.
     */

    /** List values by key. */
    private static final Map<String, List<String>> lists = new HashMap<>();

    /** Hash values (field -> value maps) by key. */
    private static final Map<String, Map<String, String>> hashes = new HashMap<>();

    /**
     * Returns the mutable list stored under {@code key}, creating and storing an empty one
     * first if the key is absent. Changes to the returned list are visible to later calls.
     */
    public static List<String> getList(String key) {
        return lists.computeIfAbsent(key, k -> new ArrayList<>());
    }

    /**
     * Returns the mutable field map stored under {@code key}, creating and storing an empty
     * one first if the key is absent. Changes to the returned map are visible to later calls.
     */
    public static Map<String, String> getHashes(String key) {
        return hashes.computeIfAbsent(key, k -> new HashMap<>());
    }
}
package redis;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
/**
 * A server-side command handler. {@code Protocol.readCommand} instantiates one per request,
 * passes it the request arguments and returns it to the caller, which then runs it.
 */
public interface Command {

    /**
     * Stores the request arguments (every element after the command name).
     *
     * @param args the remaining request elements, as produced by the protocol reader
     */
    void setArgs(List<Object> args);

    /**
     * Executes the command and writes its reply.
     *
     * @param os stream the protocol-encoded reply is written to
     * @throws IOException if writing the reply fails
     */
    void run(OutputStream os) throws IOException;
}
package redis.command;
import redis.*;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
/**
 * Handler for {@code HGET key field}: replies with the field's value as a bulk string, or a
 * null bulk string when the field is absent.
 */
public class HGETCommand implements Command {
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        // Validate before casting: a wrong arity or a non-bulk-string argument would otherwise
        // surface as IndexOutOfBoundsException/ClassCastException instead of an error reply.
        if (args == null || args.size() != 2) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'hget' command");
            return;
        }
        if (!(args.get(0) instanceof byte[]) || !(args.get(1) instanceof byte[])) {
            Protocol.writeError(os, "ERR arguments must be bulk strings");
            return;
        }
        String key = new String((byte[]) args.get(0));
        String field = new String((byte[]) args.get(1));
        // Note: Database.getHashes stores an empty hash for an unknown key, so a miss still
        // creates an entry.
        Map<String, String> hash = Database.getHashes(key);
        String value = hash.get(field);
        if (value == null) {
            Protocol.writeNull(os);
        } else {
            Protocol.writeBulkString(os, value);
        }
    }
}
package redis.command;
import redis.*;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
/**
 * Handler for {@code HSET key field value [field value ...]}: stores each field/value pair in
 * the hash at key and replies with the number of fields that did not exist before (overwriting
 * an existing field counts 0).
 */
public class HSETCommand implements Command {
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        // Need a key plus at least one complete field/value pair: an odd count >= 3.
        if (args == null || args.size() < 3 || args.size() % 2 == 0) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'hset' command");
            return;
        }
        // Check every argument up front so a bad request never leaves a partial update behind.
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }
        String key = new String((byte[]) args.get(0));
        Map<String, String> hash = Database.getHashes(key);
        long added = 0;
        for (int i = 1; i < args.size(); i += 2) {
            String field = new String((byte[]) args.get(i));
            String value = new String((byte[]) args.get(i + 1));
            if (!hash.containsKey(field)) {
                added++;
            }
            hash.put(field, value);
        }
        Protocol.writeInteger(os, added);
    }
}
package redis.command;
import redis.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.Command;
import redis.Database;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
/**
 * Handler for {@code LPUSH key value [value ...]}: inserts each value at the head of the list
 * at key, in argument order (so {@code LPUSH k a b} leaves {@code b, a}), and replies with the
 * resulting list length.
 */
public class LPUSHCommand implements Command {
    private static final Logger logger = LoggerFactory.getLogger(LPUSHCommand.class);
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        if (args == null || args.size() < 2) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'lpush' command");
            return;
        }
        // Check every argument before mutating, so a bad request never pushes a partial set.
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }
        String key = new String((byte[]) args.get(0));
        // Database hands out a shared, unsynchronized list; this is only safe while requests
        // are handled on a single thread.
        List<String> list = Database.getList(key);
        for (Object arg : args.subList(1, args.size())) {
            String value = new String((byte[]) arg);
            logger.debug("运行的是 lpush 命令: {} {}", key, value);
            list.add(0, value);
        }
        logger.debug("插入后数据共有 {} 个", list.size());
        Protocol.writeInteger(os, list.size());
    }
}
package redis.command;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.*;
import redis.Command;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
/**
 * Handler for {@code LRANGE key start stop}: replies with the elements of the list at key
 * between the two inclusive indices.
 *
 * <p>Indices follow Redis semantics. Negative values count from the tail ({@code -1} is the
 * last element). Out-of-range bounds are clamped to the list. An empty range yields an empty
 * array rather than an error.
 */
public class LRANGECommand implements Command {
    private static final Logger logger = LoggerFactory.getLogger(LRANGECommand.class);
    private List<Object> args;

    @Override
    public void setArgs(List<Object> args) {
        this.args = args;
    }

    @Override
    public void run(OutputStream os) throws IOException {
        if (args == null || args.size() != 3) {
            Protocol.writeError(os, "ERR wrong number of arguments for 'lrange' command");
            return;
        }
        for (Object arg : args) {
            if (!(arg instanceof byte[])) {
                Protocol.writeError(os, "ERR arguments must be bulk strings");
                return;
            }
        }
        String key = new String((byte[]) args.get(0));
        long start;
        long stop;
        try {
            // long rather than int so large bounds (e.g. "0 9999999999") are accepted and
            // clamped instead of rejected, and stop + 1 below cannot overflow.
            start = Long.parseLong(new String((byte[]) args.get(1)));
            stop = Long.parseLong(new String((byte[]) args.get(2)));
        } catch (NumberFormatException e) {
            Protocol.writeError(os, "ERR value is not an integer or out of range");
            return;
        }
        List<String> list = Database.getList(key);
        long size = list.size();
        if (start < 0) {
            start += size;
        }
        if (stop < 0) {
            stop += size;
        }
        start = Math.max(start, 0);
        stop = Math.min(stop, size - 1);
        // After clamping, 0 <= start and stop < size, so the int casts are safe; subList(0, 0)
        // is a valid empty view for any list.
        List<String> result = start > stop
                ? list.subList(0, 0)
                : list.subList((int) start, (int) stop + 1);
        logger.debug("lrange {} {} {}", key, start, stop);
        try {
            Protocol.writeArray(os, result);
        } catch (IOException e) {
            // Propagate I/O failures (e.g. a closed connection) instead of swallowing them.
            throw e;
        } catch (Exception e) {
            // writeArray only rejects element types it cannot encode; every element here is a
            // String, so reaching this means a bug, not bad input.
            throw new IOException("failed to encode LRANGE reply", e);
        }
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>redis-like</groupId>
    <artifactId>redis</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <!-- Conventional property that Maven plugins (compiler, resources, surefire reports) use as
             their default encoding; the sources contain non-ASCII string literals and comments. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <!-- Latest 1.7.x, aligned with the slf4j line logback 1.2.x is built against. -->
            <version>1.7.36</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <!-- 1.2.3 is affected by CVE-2021-42550 and CVE-2023-6378; 1.2.13 is the patched
                 release that still supports Java 8 and the slf4j 1.7 API. -->
            <version>1.2.13</version>
        </dependency>
    </dependencies>
</project>
改进点
1.单线程(不支持并发连接) 引入多线程
2.连接的处理循环没有停下的条件:找到并识别出对方关闭连接的情况,跳出循环
3.错误处理不够丰富
4.持久化(存到磁盘/MySQL)