使用原生java手写一个简单web服务器

   最近由于网络编程课程的结课,要求我们写一个大作业,要求使用socket,文件读写以及协议编程写一个项目。我就想着python的nio玩过,但还没怎么用java的nio写一个项目,于是灵机一动,花了将近两天,写了个简单的web服务器。代码方面没有导入其他包,而使用jdk自带工具包。下面看成果图

 

目录

成果截图

项目结构

实验原理

Connector

HTTPRequest

HTTPResponse

Controller

SessionFactory

UserSet


 

成果截图

首页:

登录失败:

登录成功:重定向到user/index.html页面

 

其他页面就不放出来了

 

项目结构

注意最后一行是user.txt文件,你当成数据库文件也行,我这边是这样,简单使用/%abc%/来当分隔符,左边是username,右边是password

 

实验原理

我们都知道socket是操作系统提供给我们的工具,让我们不用去管tcp跟udp的细节,而封装成的一个工具给编程人员使用。

因此,web服务器的原理,就是

1. 先用socket监听端口,使用tcp与浏览器进行通信

2. 在tcp的上层,业即应用层去使用http协议

3. 通常的业务处理

而在此之上,由于http是无状态的协议,因此此项目还手写了一个session来进行保存

并且,由于没有链接数据库,便使用文件读写来保存数据对象,此项目只简单序列化了下User对象

 

Connector

首先是连接器,此连接器使用的是java的nio。

使用一个主线程,设为accept状态专门用来接收客户端来的tcp连接,并再创建一个socket去处理此tcp连接的服务。由于是http协议是半双工协议,因此来自客户端的请求只会是read。

虽然使用nio可以不用多线程,但是对于这种IO密集型的程序,这里仍然使用一个线程池去处理各个客户端的请求,以提升处理速度。

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;


public class Connector {

    private Selector selector;
    private ServerSocketChannel serverSocketChannel;
    private Logger logger = Logger.getLogger(this.toString());
    private ThreadLocal<ByteBuffer> sendMap = new ThreadLocal<>();
    private ThreadLocal<ByteBuffer> receiveMap = new ThreadLocal<>();
    private ByteBuffer send = ByteBuffer.allocate(10000);
    private ByteBuffer receive = ByteBuffer.allocate(10000);
    private ThreadPoolExecutor threadPoolExecutor;
    boolean flag = false;

    public void init(int port) throws IOException {
        // 初始化线程池
        threadPoolExecutor =  new ThreadPoolExecutor(4, 10,
                60L, TimeUnit.SECONDS,
                new SynchronousQueue<Runnable>());

        serverSocketChannel = ServerSocketChannel.open();
        // 配置为非阻塞
        serverSocketChannel.configureBlocking(false);
//        ServerSocket serverSocket = serverSocketChannel.socket();
        // 绑定端口
        InetSocketAddress address = new InetSocketAddress(port);
        serverSocketChannel.bind(address);
        selector = Selector.open();
        // 在socket上使用select, 把链接放到accept槽位
        serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
        logger.info("服务正在启动....绑定的端口为" + port);
    }

    public void listen(){
        while (true){
            Iterator<SelectionKey> iterator = null;
            Set<SelectionKey> selectionKeys = null;
            try {
                selector.select();
                selectionKeys = selector.selectedKeys();
                // 遍历所有槽位
                iterator = selectionKeys.iterator();
                while (iterator.hasNext()){
                    SelectionKey key = iterator.next();
                    dealChannel(key);
                    // 把相应位置删除,免得无限处理同一个
                    iterator.remove();

                }
            } catch (Exception e) {
                e.printStackTrace();
                if(iterator != null)
                    iterator.remove();
//                System.out.println("----------");
            }
        }
    }

    // 判断key是accept,read,write哪个槽中,进行相应处理,创建相应信道
    private void dealChannel(SelectionKey selectionKey) throws Exception {
        SocketChannel client = null;
        try {
            if(!selectionKey.isValid()) {
                selectionKey.cancel();
                return;
            }
            // 若是主TCP连接,则发放客户端链接给子信道处理
            if (selectionKey.isAcceptable()) {
                // 获取父信道,转换成ServerSocketChannel是因为此接口才有accept的阻塞方法,不然得轮询
                ServerSocketChannel serverChannel = (ServerSocketChannel) selectionKey.channel();
                // 平常阻塞着,一旦有请求连接,主Channel创建子channel来处理业务
                client = serverChannel.accept();
                client.configureBlocking(false);
                // 子信道为了处理业务,所以可读
                client.register(selector, SelectionKey.OP_READ);

            }
            // 如果是可读的
            else if (selectionKey.isReadable()) {
                selectionKey.cancel();
                client = (SocketChannel) selectionKey.channel();
                if (!client.isOpen())
                    return;
                ByteBuffer byteBuffer = ByteBuffer.allocate(4096);
                int count = client.read(byteBuffer);
                if (count <= 0) {
                    // 关闭都交给子线程
//                    client.close();
                    return;
                }

                final SocketChannel threadClient = client;
                byteBuffer.position(0);
                // 使用线程池执行业务逻辑
                threadPoolExecutor.execute(() ->{
                    try {
                        // 解析http
                        HTTPRequest headers = new HTTPParse().parse(byteBuffer);
                        // 处理并返回response给客户端
                        new Controller().control(headers, threadClient);
                    }catch (IOException e){
                        e.printStackTrace();
                    }finally {
                        try {
                            threadClient.close();
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    }

                });

            }
            // 如果是可写的
            else if (selectionKey.isWritable()) {
                client = (SocketChannel) selectionKey.channel();
                if (!client.isOpen())
                    return;
                client.write(send);
                client.close();
            }
        }
        catch (Exception e){
            if(client != null)
                client.close();
            throw e;

        }
    }


    public static void main(String[] args) {
        Connector connector = new Connector();
        int port = 8082;
        Controller.HOST = "http://localhost:" + port;
        try {
//            connector.test();
            connector.init(port);
            connector.listen();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

}


 

 

可以看到,上面的线程池,会先执行HTTP的解析,然后再执行Controller的业务。

首先看HTTPParse类



import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

public class HTTPParse {
    private static final char CR = '\r';
    private static final char LF = '\n';
    private static final char HT = '\t';
    private static final char SP = ' ';
    private static final char CO = ':';

    private StringBuilder sb = new StringBuilder();
    private HTTPRequest HTTPRequest = new HTTPRequest();

    enum State { STATUS_LINE,
        HEADER,
        HEADER_FOUND_CR,
        HEADER_FOUND_LF,
        HEADER_FOUND_CR_LF,
        HEADER_FOUND_CR_LF_CR,
        FINISHED }

    private HTTPParse.State state = HTTPParse.State.STATUS_LINE;

    public HTTPRequest parse(ByteBuffer input) throws IOException {
        assert input != null;
//        if(input.hasRemaining()){
//            System.out.println(new String(input.array()));
//        }
        while (input.hasRemaining() && state != State.FINISHED) {
            switch (state) {
                // 请求行
                case STATUS_LINE:
                    readResumeStatusLine(input);
                    break;
                case HEADER:
                    readResumeHeader(input);
                    break;
                case HEADER_FOUND_CR:
                case HEADER_FOUND_LF:
                    resumeOrLF(input);
                    break;
                case HEADER_FOUND_CR_LF:
                    resumeOrSecondCR(input);
                    break;
                case HEADER_FOUND_CR_LF_CR:
                    resumeOrEndHeaders(input);
                    break;
                default:
                    throw new InternalError(
                            "Unexpected state: " + String.valueOf(state));
            }
        }
        return HTTPRequest;
    }

    private void resumeOrLF(ByteBuffer input) throws ProtocolException {
        // 回车符后换行符
        char c = (char)input.get();
        if(c != LF && c != CR)
            throw new ProtocolException("这协议不对劲");
        state = State.HEADER_FOUND_CR_LF;
    }


    private void readResumeHeader(ByteBuffer input) throws ProtocolException {
        String name = null;
        String content;
        boolean first = true;
        while (input.hasRemaining()) {
            char c = (char)input.get();

            if (c == CR)
                break;
            else if (c == HT)
                c = SP;
            // 遇到冒号(只允许第一次)就获取header的名字
            else if(c == CO) {
                // 第二次匹配到冒号就当成正常的字符
                if(first) {
                    first = false;
                    name = sb.toString();
                    sb = new StringBuilder();
                    continue;
                }
            }
            else if(c == SP)
                continue;

            sb.append(c);
        }
        content = sb.toString();
        sb = new StringBuilder();
        parseHeaders(name, content);
        state = State.HEADER_FOUND_CR;
    }

    private void resumeOrEndHeaders(ByteBuffer input) throws ProtocolException {
        // 先把最后的/n去掉
        input.get();
        int length = input.limit() - input.position();
        byte[] body = new byte[length];
        int index = 0;
        // 剩下的数据当成body
        while (input.hasRemaining()){
            body[index++] = input.get();
        }
        HTTPRequest.setBody(body);
        state = State.FINISHED;
    }

    /**
     *  到这里代表已经一行结束,接着检测下行开头
     * @param input
     * @throws ProtocolException
     */
    private void resumeOrSecondCR(ByteBuffer input) throws ProtocolException {
        char c = (char)input.get();
        if(c == CR) {
            state = State.HEADER_FOUND_CR_LF_CR;
        } else if (c == SP || c == HT) {
            state = State.HEADER;
        }else{
            // 若是普通字符,放回去,打扰了
            sb.append(c);
            state = State.HEADER;
        }
    }

    private void readResumeStatusLine(ByteBuffer input) throws IOException {
        // 获取method
        char c;
        while (input.hasRemaining() && (c =(char)input.get()) != SP)
            sb.append(c);
        parseMethod(sb.toString());
        sb = new StringBuilder();
        // 获取url
        while (input.hasRemaining() && (c =(char)input.get()) != SP)
            sb.append(c);
        parseURL(sb.toString());
        sb = new StringBuilder();

        // 获取协议版本。。。算了
        while (input.hasRemaining() && (c =(char)input.get()) != CR);
        // 遇到换行符下一个必须是LF
        if((char)input.get() != LF)
            throw new ProtocolException("这协议不对劲");
        // 清空
        sb = new StringBuilder();
        // 下面读取的就是http头部信息
        state = State.HEADER;
    }



    private void parseURL(String url) throws MalformedURLException {
        HTTPRequest.setUrl(url);
    }

    /**
     * 解析方法
     * @param method
     * @throws ProtocolException
     */
    public void parseMethod(String method) throws ProtocolException {
        if(method.equals("GET")){
            HTTPRequest.setMethod(HTTPRequest.GET);
        }
        else if(method.equals("POST")){
            HTTPRequest.setMethod(HTTPRequest.POST);
        }else{
            throw new ProtocolException("请求方法不是GET或者POST,再你妈的见");
        }
    }

    private Map<String, String> parseCookie(String cookieStr){
        Map<String, String> cookie = new HashMap<>();
        String[] cookeis = cookieStr.split(";");
        // 遍历所有cookie
        for (int i = 0; i < cookeis.length; i++) {
            String[] cookieItem = cookeis[i].split("=");
            // 冒号前半部分作为key,后半部分作为value
            if(cookieItem.length >= 2)
                cookie.put(cookieItem[0], cookieItem[1]);
        }
        return cookie;
    }

    public void parseHeaders(String name, String content){
        assert name != null;
        if(name.equals("Host")){
            HTTPRequest.setHost(content);
        }else if(name.equals("Cookie")){
            HTTPRequest.setCookie(parseCookie(content));
        }else{
            // 其他请求头就不处理了,再见
        }
    }

    class ProtocolException extends IOException{
        public ProtocolException(String message) {
            super(message);
        }
    }
}

 

我们知道,http报文的请求格式是这样的

因此解析器就像上面,使用一个while循环,知道解析结束,switch中的STATUS_LINE就代表第一行状态行的解析,HEADER就是对请求头的解析等。如果协议错误,会抛出ProtocolException。

 

HTTPRequest

import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

public class HTTPRequest {

    public static final int POST = 0;
    public static final int GET = 1;

    private String url;
    private int method;
    private String host;
    private String location;
    private byte[] body;
    private Map<String, String> cookie = new HashMap<>();

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public int getMethod() {
        return method;
    }

    public void setMethod(int method) {
        this.method = method;
    }

    public String getLocation() {
        return location;
    }

    public void setLocation(String location) {
        this.location = location;
    }

    public String getHost() {
        return host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    public Map<String, String> getCookie() {
        return cookie;
    }

    public void setCookie(Map<String, String> cookie) {
        this.cookie = cookie;
    }

    public byte[] getBody() {
        return body;
    }

    public void setBody(byte[] body) {
        this.body = body;
    }

    @Override
    public String toString() {
        return "HTTPRequest{" +
                "url='" + url + '\'' +
                ", method=" + method +
                ", host='" + host + '\'' +
                ", cookie=" + cookie +
                '}';
    }
}

这里创建一个简单的http请求对象类,看代码可以发现,这里只是简单处理了POST跟GET请求,并且header部分只处理了host,location以及cookie字段,其他部分不做解析。由于一个线程处理一个request,所以这里的cookie简单地使用了HashMap而已

 

HTTPResponse

import java.text.SimpleDateFormat;
import java.util.*;

public class HTTPResponse {
    private String line;
//    private String headers;
    private Map<String, String> headers = new HashMap<>();
    private byte[] body;
    private String end;
    private Map<String, String> cookie = new HashMap<>();

    int status;

    HTTPResponse(){
        this.line = "HTTP/1.1 200 OK\r\n";
        headers.put("Server", "Nginx");
        headers.put("Content-Type", "text/html; charset=UTF-8");
        headers.put("Connection", "keep-alive");
        headers.put("Date", getDate());

        this.end = "\r\n\r\n";
    }

    /**
     * 获取符合http协议的Date字段的时间格式,时区是GMT
     * @return
     */
    private String getDate(){
        Date date = new Date();
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("E, d MMM yyyy HH:MM:ss z", Locale.ENGLISH);
        simpleDateFormat.setTimeZone(TimeZone.getTimeZone("GMT"));
        String now = simpleDateFormat.format(date);
        return now;
    }

    public Map<String, String> getCookie() {
        return cookie;
    }

    public void setCookie(Map<String, String> cookie) {
        this.cookie = cookie;
    }

    public HTTPResponse setContentType(String contentType){
        headers.put("Content-Type", contentType);
        return this;
    }


    public HTTPResponse setStatus(int status){
        if(status == 200){
            this.status = 200;
            this.line = "HTTP/1.1 200 OK\r\n";
        }else if(status == 302){
            this.status = 302;
            this.line = "HTTP/1.1 302 Found\r\n";
        }
        else{
            this.status = 404;
            this.line = "HTTP/1.1 404 Not Found\r\n";
            String html = "<!DOCTYPE html>" +
                    "<html>" +
                    "<head>" +
                    "<meta charset=\"utf-8\">" +
                    "<title></title>" +
                    "</head>" +
                    "<body>" +
                    "<h1> 404!该页面不存在!</h1>" +
                    "</body>" +
                    "</html>";
            this.body = html.getBytes();
        }
        return this;
    }

    public int getStatus() {
        return status;
    }

    public void addHeaders(String key, String value) {
        this.headers.put(key, value);
    }

    public HTTPResponse setBody(byte[] body) {
        this.body = body;
        return this;
    }

    private String getHeadersStr(){
        StringBuilder sb = new StringBuilder();
        headers.forEach((key, value) ->{
            sb.append(key).append(": ").append(value).append("\r\n");
        });
        if(status == 200) {
            cookie.forEach((key, value) -> {
                sb.append("Set-Cookie: ");
                sb.append(key).append("=").append(value).append(";").append("\r\n");
            });
        }
        sb.append("\r\n");
        return sb.toString();
    }

    public byte[] toBytes(){

        byte[] lineBytes = line.getBytes();
        byte[] headersBytes = getHeadersStr().getBytes();
        byte[] endBytes = end.getBytes();

        // 将几个字节数组相加
        byte[] result = new byte[lineBytes.length + headersBytes.length + body.length + endBytes.length];
        System.arraycopy(lineBytes, 0, result, 0, lineBytes.length);
        System.arraycopy(headersBytes, 0, result, lineBytes.length, headersBytes.length);
        System.arraycopy(body, 0, result, lineBytes.length + headersBytes.length, body.length);
        int length = lineBytes.length + headersBytes.length + body.length;
        System.arraycopy(endBytes, 0, result, length, endBytes.length);
        return result;
    }
}

 

Resopnse实体,主要有响应码,这里使用了三个常用的响应码,一个是200,一个是404,另一个是302临时重定向。并设置几个常见的http相应头,Server,Content-Type,Connection,Date以及Set-Cookie。但实际上不止这几个,后面会根据302,设置localtion的请求头属性,来通知浏览器跳转

 

Controller

 





import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

public class Controller {

    public static String HOST = "http://localhost:8082";
    public static String STATIC_DIR = "templates";
    public Logger logger = Logger.getLogger(Controller.class.getName());
    // 注意这里的session没有封装到request里
    private Map<String, Object> session;
    private String sessionid;
    // 客户端的channel
    private SocketChannel client;
    private HTTPResponse response;
    // 该属性标志url请求的是否是静态资源
    private boolean isStatic;
    private String url;


    public void control(HTTPRequest request, SocketChannel client) throws IOException {
        // 初始化,用来设置类属性
        init(request,client);

        // 首页路径处理
        if(url.equals("/")){
             response = getFileResponse("/index.html");
        }
        else if(url.endsWith("/index")) {
            response = getFileResponse(url + ".html");
        }
        // 登录页处理,注册功能懒得写,下次一定
        else if (url.endsWith("/login")){
            if(request.getMethod() == HTTPRequest.POST){
                Map<String, String> params = parsePostParam(request);
                // 如果登录成功
                if (login(params)) {
                    response = new HTTPResponse().setStatus(302).setBody(new byte[2]);
                    response.addHeaders("Location", HOST + "/user/index.html");
//                    response = getFileResponse("/user/index.html");
                }else{
                    // 懒得改前端,就多个页面吧
                    response = getFileResponse("/indexerror.html");
                }
            }else if(request.getMethod() == HTTPRequest.GET){
                response = getFileResponse("/index.html");
//                response = new HTTPResponse().setStatus(302).setBody(new byte[2]);
//                response.addHeaders("Location", HOST + "/index.html");
            }
        }
        // 个人页处理
        else if(url.startsWith("/user/")) {
            if (!isStatic && session.get("_user") == null) {
                response = getFileResponse(url);
                // 如果是html文件,且存在(这么写的原因是isStatic判断并不完全准确)
                if(url.endsWith(".html") && response.getStatus() != 404){
                    response.setStatus(302).setBody(new byte[2]);
                    response.addHeaders("Location", HOST + "/index");
                }
            } else {
                response = getFileResponse(url);
                response.setStatus(200);
            }
        }
        // 直接获取文件页面的处理,如静态资源与html
        else{
            response = getFileResponse(url);
        }


        // 设置set_cookie
        if(!isStatic) {
            Map<String, String> setCookie = response.getCookie();
            setCookie.put("session_id", sessionid);
        }

        sendData(response);
    }

    private void init(HTTPRequest request, SocketChannel client){
        this.client = client;
        // url处理
        url = request.getUrl();
        logger.info(Thread.currentThread().toString() + "\t访问路径为:" +  url);
        if(url.contains("?")){
            // 不匹配参数,打扰了
            String[] split = url.split("[?]");
            url = split[0];
        }
        // 看看url是不是要获取的是静态文件,是的话放行
        isStatic = isStaticResource(url);

        // 获取session
        Map<String, String> cookie = request.getCookie();
        sessionid = cookie.get("session_id");

        if(!isStatic){
            // 验证并获取一个有效的sessionid
            sessionid = SessionFactory.getInstance().getVaildSessionId(sessionid);
            session = SessionFactory.getInstance().getSession(sessionid);
        }
    }

    /**
     * 判断url是否要获取静态资源
     * @param url
     * @return
     */
    private boolean isStaticResource(String url){
        if(url.endsWith(".css") || url.endsWith(".js")
                || url.endsWith(".png") || url.endsWith(".jpg")
                || url.endsWith(".jpeg") || url.endsWith(".woff2")
                || url.endsWith(".woff") || url.endsWith(".ttf")){
            return true;
        }
        return false;
    }

    /**
     * 根据传入的文件路径获取对应的文件,需要注意path第一字符为/
     * @param path
     * @return
     * @throws IOException
     */
    private HTTPResponse getFileResponse(String path) throws IOException {
        // 读取对应的html文件作为body,path自带/
        File tempfile = new File(STATIC_DIR + path);
        File file;
        FileInputStream fileInputStream = null;
        byte[] body;
        try {
            // 将文件读取到body数组中,作为response的body
            file = tempfile.getCanonicalFile();
            fileInputStream = new FileInputStream(file);
            body = fileInputStream.readAllBytes();
        }catch (FileNotFoundException e){
            // 文件不存在,直接404警告
            HTTPResponse httpResponse = new HTTPResponse().setStatus(404);
//            sendData(httpResponse);
            return httpResponse;
        }
        finally {
            if(fileInputStream != null)
                fileInputStream.close();
        }

        // 设置response的各种参数
        HTTPResponse httpResponse = new HTTPResponse().setBody(body).setStatus(200);
        // 对静态资源设置不同的contentType
        if(path.endsWith(".css")) {
            httpResponse.setContentType("text/css");
        }else if(path.endsWith(".js")){
            httpResponse.setContentType("application/javascript");
        }else if(path.endsWith(".png")){
            httpResponse.setContentType("image/png");
        }else if(path.endsWith(".jpg")){
            httpResponse.setContentType("image/jpg");
        }else if(path.endsWith(".jpeg")){
            httpResponse.setContentType("image/jpeg");
        }else if(path.endsWith(".woff2") || path.endsWith(".woff") || path.endsWith(".ttf")){
            httpResponse.setContentType("");
        }

        return httpResponse;
    }

    /**
     * 用于发送response给客户端
     * @param httpResponse
     * @throws IOException
     */
    private void sendData(HTTPResponse httpResponse) throws IOException {
        byte[] response = httpResponse.toBytes();
        ByteBuffer send = ByteBuffer.allocate(response.length);
        send.limit(response.length);
        send.put(response);
        send.position(0);
        client.write(send);
    }

    private Map<String, String> parsePostParam(HTTPRequest request){
        HashMap<String, String> result = new HashMap<>();
        String params = new String(request.getBody());
        String[] split = params.split("&");
        if(split.length >= 2){
            // 遍历所有参数
            for (int i = 0; i < split.length; i++) {
                String[] param = split[i].split("=");
                // 把参数的key与value写入map
                if(param.length >= 2){
                    result.put(param[0], param[1]);
                }
            }
        }
        return result;
    }

    /**
     * 验证登录用户,并放到session中。这里使用文件读写来实现账号保存
     * @param params
     * @return
     */
    public boolean login(Map<String, String> params){
        String username = params.get("username").trim();
        String password = params.get("password").trim();
        UserSet.User user = UserSet.getInstance().getUser(username, password);
        if(user != null){
            session.put("_user", user);
            return true;
        }
        return false;

    }

}

 

Controller是业务执行逻辑,跟SpringMVC的Controller不同,这里的Controller对象不是单例的,而是每个线程都会新建一个。用来处理业务访问以及登录的逻辑。其中的control是主要执行的逻辑,线程池都是从这里执行的。后面的SessionFactory跟UserSet也是由这个类直接操纵

 

SessionFactory

 

import java.util.Date;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

public class SessionFactory {
    volatile static private SessionFactory instance;
    private Map<String, Map<String,Object>> sessions = new ConcurrentHashMap<>();
    // 用来存放session的过期时间,session不给续约一说,一旦过期,就只能更换
    private Map<String, Long> expires = new ConcurrentHashMap<>();

    private SessionFactory(){}

    public static SessionFactory getInstance(){
        if (instance == null){
            synchronized (SessionFactory.class){
                if(instance == null){
                    instance = new SessionFactory();
                }
            }
        }
        return instance;
    }

    /**
     * 敲黑板,这不算传统的工厂模式直接返回对象,而是普通的get方法
     * @param id
     * @return
     */
    public Map<String, Object> getSession(String id){
        return sessions.get(id);
    }

    /**
     * 敲黑板,返回的是session_id,这才是传统工厂模式返回的对象
     * @return
     */
    public String createSession(){
        String id = UUID.randomUUID().toString().replace("-", "");
        sessions.put(id, new ConcurrentHashMap<>());
        long now = new Date().getTime();
        // 简单写死,设定为3天后过期
        long expireTime  = now + TimeUnit.DAYS.toMillis(3);
        expires.put(id, expireTime);
        return id;
    }

    /**
     * 如果返回值为true代表过期
     * @param id
     * @return
     */
    public boolean checkExpire(String id){
        long now = new Date().getTime();
        long expireTime = expires.get(id);
        return now > expireTime;
    }


    /**
     * 如果session_id过期,返回新的session_id,并将过期的处理掉
     * @return
     */
    public String getVaildSessionId(String id){
        // 如果sessionid不存在,返回一个有效的sessionid
        if(id == null || sessions.get(id) == null) {
            return createSession();
        }else {
            // 如果sessionid过期,清除并返回一个有效的sessionid
            if (checkExpire(id))
                return flush(id);
            return id;
        }
    }


    public String flush(String id){
        Map<String, Object> oldSession = sessions.get(id);
        sessions.remove(id);
        expires.remove(id);
        String newId = createSession();
        sessions.put(newId, oldSession);
        return newId;
    }

}

SessionFactory人如其名,用来生成对应Session,它是一个单例工厂。所有线程共用一个SessionFactory,保证Session一致性。由于使用到了多线程,这里使用了双重锁校验,另外session过期时间默认为3天。但跟tomcat的JSESSION一样,没有通知客户端中session_id的过期时间。

 

UserSet

 

import java.io.*;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;

public class UserSet {
    private volatile static UserSet instance;
    public static String FILE_NAME = "user.txt";
    public static String SPNARATOR = "/%abc%/";
    private Set<User> userSet = new CopyOnWriteArraySet<>();

    private UserSet(){}

    public static UserSet getInstance() {
        if(instance == null){
            synchronized (UserSet.class){
                if(instance == null){
                    instance = new UserSet();
                }
            }
        }
        return instance;
    }

    // 将文件内容反序列化到userSet中
    void read(){
        File file = new File(FILE_NAME);
        if(!file.exists()) {
            return;
        }
        try {
            Reader reader = new InputStreamReader(new FileInputStream(file));
            BufferedReader br = new BufferedReader(reader);
            String line = br.readLine();
            while (line != null){
                String[] split = line.split(SPNARATOR);
                if(split.length >= 2){
                    User user = new User();
                    user.username = split[0];
                    user.password = split[1];
                    userSet.add(user);
                }
                line = br.readLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public User getUser(String username, String password){
        if(contanis(username, password)){
            return new User(username, password);
        }
        return null;
    }

    public boolean contanis(String username, String password){
        User user = new User();
        user.username = username;
        user.password = password;
        if(userSet.size() == 0){
            read();
        }
        return userSet.contains(user);
    }

    public class User{
        private String username;
        private String password;

        public User(String username, String password) {
            this.username = username;
            this.password = password;
        }
        public User(){};

        @Override
        public boolean equals(Object obj) {
            User user = (User) obj;
            return username.equals(user.username) &&
                    password.equals(user.password);
        }

        @Override
        public int hashCode() {
            int h1 = username.hashCode();
            int h2 = password.hashCode();
            // 随便写的,我也不知道冲突率怎样。
            int h = h1 ^ h2;
            return (h) ^ (h >>> 16);
        }

        @Override
        public String toString() {
            return "User{" +
                    "username='" + username + '\'' +
                    ", password='" + password + '\'' +
                    '}';
        }
    }
}

UserSet是一个单例对象,用来存放所有的User对象的集合。它会被Controller调用来验证账号密码。UserSet把默认为user.txt里的文件数据按格式反序列化成User对象。另外内部类定义了一个User对象,重写了equals方法,使其只要username跟password属性相等,就认为是同一个对象。另外,由于时间关系,这服务器没有注册方法,也就是没有写文件写入的操作。有需要新账号的自己手动写到user.txt文件中

 

结语

原理说起来简单,但是实现细节方面还是遇到了不少bug,并且代码方面还有很多可以优化的地方。毕竟连注册功能都没有,懒得写。

不得不吐槽的是,java的nio太不人性化了,确实反人类,难怪一堆人吐槽java的nio难写,转而使用netty。

另外,使用多线程来执行nio,还是遇到了许多的bug,比如AsynchronousCloseException, ClosedChannelException,CancelledKeyException等等一堆在单线程时看不到的异常,调试了挺久才解决的,暂时没遇到了,但不能保证后续绝对没有bug,姑且是改成了遇到这些异常也能正常处理后续请求。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值