项目实训(7) - 简易网关


前言

本文实现一个简易网关：通过一组 Servlet Filter 依次完成跨域设置、响应写回、权限校验和路由转发，并使用轮询方式在同一服务的多个实例之间做负载均衡。

Filters

CorsFilter

/**
 * Adds CORS headers to every response that passes through the gateway.
 *
 * <p>Because {@code Access-Control-Allow-Credentials} is {@code true}, browsers refuse a
 * wildcard {@code Access-Control-Allow-Origin} and interpret {@code *} in
 * {@code Access-Control-Allow-Methods} literally. The caller's {@code Origin} is therefore
 * echoed back, and the allowed methods are listed explicitly.
 *
 * <p>NOTE(review): echoing any origin while allowing credentials effectively trusts every
 * site. That is acceptable while authentication relies on the {@code Authorization} header,
 * but restrict it to an allow-list if cookies are ever used.
 */
@Order(0)
@Component
@WebFilter(filterName = "CorsFilter", urlPatterns = {"/*"})
public class CorsFilter implements Filter {

    /** Methods advertised to browsers; must be explicit because credentials are allowed. */
    private static final String ALLOWED_METHODS = "GET,POST,PUT,DELETE,PATCH,HEAD,OPTIONS";

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String origin = request.getHeader("Origin");
        if (origin != null && !origin.isEmpty()) {
            // Echo the concrete origin: "*" is rejected by browsers when credentials are allowed.
            response.setHeader("Access-Control-Allow-Origin", origin);
            // The header now depends on the request, so shared caches must key on Origin.
            response.addHeader("Vary", "Origin");
        } else {
            // Requests without an Origin header (e.g. non-browser clients) keep the wildcard.
            response.setHeader("Access-Control-Allow-Origin", "*");
        }
        response.setHeader("Access-Control-Allow-Methods", ALLOWED_METHODS);
        response.setHeader("Access-Control-Allow-Credentials", "true");
        response.setHeader("Access-Control-Allow-Headers", "Content-Type,Authorization");
        response.setHeader("Access-Control-Expose-Headers", "*");

        filterChain.doFilter(request, response);
    }

}

ResponseFilter

/**
 * Answers OPTIONS requests directly, rejects oversized request bodies, and otherwise writes
 * whatever payload the inner filters stored in the {@code "bytes"} request attribute.
 */
@Order(1)
@Component
@WebFilter(filterName = "ResponseFilter", urlPatterns = {"/*"})
public class ResponseFilter implements Filter {

    /** Largest accepted request body: 8 MB. */
    private static final long MAX_CONTENT_LENGTH = 8L * 1024 * 1024;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        if (Objects.equals(request.getMethod(), HttpMethod.OPTIONS.name())) {
            // OPTIONS requests get a bare 200 and are not passed down the chain.
            response.setStatus(200);
            return;
        }

        if (request.getContentLengthLong() > MAX_CONTENT_LENGTH) {
            ResponseData responseData = new ResponseData(400, "数据过大");
            byte[] bytes = JSON.toJSONBytes(responseData);
            request.setAttribute("bytes", bytes);
            response.setContentType(MediaType.APPLICATION_JSON_VALUE);
            // Keep the HTTP status consistent with the code carried in the JSON body.
            response.setStatus(400);
        } else {
            filterChain.doFilter(request, response);
        }

        byte[] bytes = (byte[]) request.getAttribute("bytes");
        if (bytes == null) {
            // No payload was stored by the inner filters; writing null would throw an NPE.
            return;
        }
        OutputStream out = response.getOutputStream();
        out.write(bytes);
        out.flush();
        out.close();
    }
}

AuthorizationFilter

/**
 * Requires a valid JWT in the {@code Authorization} header, except for a small set of public
 * endpoints (login, register, verification code, picture files). On failure it stores a 401
 * JSON payload in the {@code "bytes"} request attribute and stops the chain.
 */
@Order(2)
@Component
@WebFilter(filterName = "AuthorizationFilter", urlPatterns = {"/*"})
public class AuthorizationFilter implements Filter {

    /**
     * Path fragments reachable without a token. They are matched with {@code contains}
     * because the service prefix in front of them varies per backend.
     */
    private static final String[] PUBLIC_PATHS = {
            "/login_and_register/login",
            "/login_and_register/register",
            "/login_and_register/code",
            "/file/pic"
    };

    @Autowired
    private JwtTokenUtil jwtTokenUtil;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String uri = request.getRequestURI();

        if (!isPublic(uri)) {
            // Validate the token
            String authorization = request.getHeader("Authorization");
            if (authorization == null || !jwtTokenUtil.validateToken(authorization)) {
                ResponseData responseData = new ResponseData(401, "未登录");
                byte[] bytes = JSON.toJSONBytes(responseData);
                request.setAttribute("bytes", bytes);
                response.setContentType("application/json");
                response.setStatus(401);
                return;
            }
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Returns whether {@code uri} targets one of the public endpoints.
     *
     * <p>The raw request URI is forwarded unchanged to the backend, so a URI such as
     * {@code /svc/login_and_register/login/../../svc/admin} contains a public fragment yet may be
     * normalised by the backend into a protected path. URIs with dot-segments, percent-encoded
     * dots or path parameters are therefore never treated as public.
     *
     * @param uri raw request URI
     * @return {@code true} if no token is required
     */
    private static boolean isPublic(String uri) {
        if (uri.contains("..") || uri.contains("%2e") || uri.contains("%2E") || uri.contains(";")) {
            return false;
        }
        for (String publicPath : PUBLIC_PATHS) {
            if (uri.contains(publicPath)) {
                return true;
            }
        }
        return false;
    }
}

RoutingFilter

/**
 * 路由转发
 */
@Order(3)
@Component
@WebFilter(filterName = "RoutingFilter", urlPatterns = {"/*"})
public class RoutingFilter implements Filter {

    @Autowired
    private Router router;

    @Autowired
    private RestTemplate restTemplate;

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        ResponseData responseData = null;

        String uri = request.getRequestURI();
        Service service = getService(router.getServices(), uri);
        if (service == null) {
            responseData = new ResponseData(404, "不存在的服务");
            response.setStatus(404);
        } else {
            String url = buildUrl(request, service.getNextIp(), service.getPath());
            try {
                RequestEntity<byte[]> requestEntity = buildRequestEntity(request, url);
                ResponseEntity<byte[]> exchange = restTemplate.exchange(requestEntity, byte[].class);
                byte[] bytes = exchange.getBody();
                if (bytes != null) {
                    request.setAttribute("bytes", bytes);
                    response.setContentType(exchange.getHeaders().getContentType() != null
                            ? exchange.getHeaders().getContentType().toString()
                            : MediaType.APPLICATION_JSON_VALUE);
                    return;
                } else {
                    responseData = new ResponseData(404, "资源不存在");
                }
            } catch (URISyntaxException exception) {
                responseData = new ResponseData(500, "服务器故障,uri创建失败");
                response.setStatus(500);
            } catch (ResourceAccessException exception) {
                responseData = new ResponseData(500, "响应超时");
                response.setStatus(500);
            } catch (HttpClientErrorException exception) {
                responseData = new ResponseData(exception.getRawStatusCode(), exception.getMessage());
                response.setStatus(exception.getRawStatusCode());
            }
        }

        byte[] bytes = JSON.toJSONBytes(responseData);
        request.setAttribute("bytes", bytes);
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);

    }

    /**
     * 根据请求中的部分参数获取对应的Service对象
     * @param services 服务列表
     * @param uri 原始uri
     * @return Service对象
     */
    private Service getService(List<Service> services, String uri) {
        Service service = null;

        for (Service temp : services) {
            if (temp.support(uri)) {
                service = temp;
                break;
            }
        }

        return service;
    }

    /**
     * 轮询ip地址
     * @param service 非空的Service对象
     * @return ip地址
     */
    private String getIp(Service service) {
        int index = service.getCounter().getAndIncrement();
        return service.getIp()[index % service.size()];
    }

    /**
     * 构造重定向的url
     * @param request 原始请求
     * @param ip 服务器ip
     * @param path 原始请求中的部分路径
     * @return 重定向的url
     */
    private String buildUrl(HttpServletRequest request, String ip, String path) {
        String query = request.getQueryString();
        return ip + request.getRequestURI().replace(path, "") +
                (query != null ? "?" + query : "");
    }

    /**
     * 创建RequestEntity<byte[]>
     * @param request 原始请求
     * @param url 重定向的url
     * @return RequestEntity\<byte[]\>对象
     * @throws IOException IO异常
     * @throws URISyntaxException 创建URI时抛出的异常
     */
    private RequestEntity<byte[]> buildRequestEntity(HttpServletRequest request, String url) throws IOException, URISyntaxException {
        HttpMethod method = HttpMethod.resolve(request.getMethod());
        HttpHeaders header = parseRequestHeaders(request);
        byte[] body = parseRequestBody(request);
        return new RequestEntity<byte[]>(body, header, method, new URI(url));
    }

    /**
     * 构建请求头
     * @param request 原始请求
     * @return 请求头对象
     */
    private HttpHeaders parseRequestHeaders(HttpServletRequest request) {
        // 构建请求头
        HttpHeaders headers = new HttpHeaders();
        Enumeration<String> headerNames = request.getHeaderNames();
        while (headerNames.hasMoreElements()) {
            String name = headerNames.nextElement();
            String value = request.getHeader(name);
            headers.add(name, value);
        }
        return headers;
    }

    /**
     * 构建请求体
     * @param request 原始请求
     * @return byte[]数组
     * @throws IOException IO异常
     */
    private byte[] parseRequestBody(HttpServletRequest request)
            throws IOException {
        InputStream input = request.getInputStream();
        return StreamUtils.copyToByteArray(input);
    }

    private byte[] read(InputStream input) {
        ByteArrayOutputStream output = null;
        try {
            output = new ByteArrayOutputStream();
            byte[] buffer = new byte[1024];
            int read = -1;
            while ((read = input.read(buffer)) != -1) {
                output.write(buffer, 0, read);
            }
            return output.toByteArray();
        } catch (IOException exception) {
            System.out.println(exception.getMessage());
            return null;
        } finally {
            try {
                if (input != null) {
                    input.close();
                }
                if (output != null) {
                    output.close();
                }
            } catch (IOException exception) {
                System.out.println(exception.getMessage());
            }
        }
    }

}

LoadBalancer

接口

/**
 * Strategy that decides which backend instance receives the next request.
 */
public interface LoadBalancer {

    /**
     * Chooses the instance that should serve the next request.
     *
     * @param length number of available instances
     * @return index of the chosen instance; callers use it directly as an array index, so it
     *         is expected to lie in {@code [0, length)}
     */
    int getNextService(int length);

}

默认实现

/**
 * Default load balancer: round-robin over the available instances.
 *
 * <p>Thread-safety depends on {@code AtomicCounter}. Presumably it is an atomic counter —
 * TODO confirm.
 */
public class DefaultLoadBalancer implements LoadBalancer {

    private final AtomicCounter counter;

    public DefaultLoadBalancer() {
        counter = new AtomicCounter();
    }

    /**
     * Returns the next index in round-robin order.
     *
     * @param length number of instances; must be positive
     * @return an index in {@code [0, length)}
     * @throws IllegalArgumentException if {@code length <= 0}
     */
    @Override
    public int getNextService(int length) {
        if (length <= 0) {
            throw new IllegalArgumentException("length must be positive: " + length);
        }
        // floorMod keeps the index non-negative if the counter ever wraps around to a negative
        // value; the plain % operator would then yield a negative (invalid) array index.
        return Math.floorMod(counter.getAndIncrement(), length);
    }

}

Service

/**
 * A routable backend service: a URI path prefix plus the addresses of its instances.
 * Instances are chosen by a per-service {@link DefaultLoadBalancer}.
 */
public class Service {

    private String serviceName;
    private String path;
    private String[] ip;
    private final AtomicCounter counter = new AtomicCounter();
    private final LoadBalancer balancer = new DefaultLoadBalancer();

    public Service() {}

    public Service(String serviceName, String path, String[] ip) {
        this.serviceName = serviceName;
        this.path = path;
        this.ip = ip;
    }

    public String getServiceName() {
        return serviceName;
    }

    public void setServiceName(String serviceName) {
        this.serviceName = serviceName;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String[] getIp() {
        return ip;
    }

    public void setIp(String[] ip) {
        this.ip = ip;
    }

    public AtomicCounter getCounter() {
        return counter;
    }

    /**
     * @return number of configured instance addresses
     */
    public int size() {
        return this.ip.length;
    }

    /**
     * Returns whether this service handles the given request URI.
     *
     * <p>The URI must start with {@link #getPath()} at a path-segment boundary, so the prefix
     * {@code /user} matches {@code /user} and {@code /user/info} but not {@code /users/list}.
     *
     * @param uri request URI
     * @return {@code true} if the URI belongs to this service
     */
    public boolean support(String uri) {
        if (!uri.startsWith(this.path)) {
            return false;
        }
        int end = this.path.length();
        return uri.length() == end
                || this.path.endsWith("/")
                || uri.charAt(end) == '/';
    }

    /**
     * Picks the next instance address using the load balancer.
     *
     * @return one of the configured addresses
     */
    public String getNextIp() {
        int index = balancer.getNextService(this.ip.length);
        return this.ip[index];
    }

    @Override
    public String toString() {
        return "Service{" +
                "serviceName='" + serviceName + '\'' +
                ", path='" + path + '\'' +
                ", ip=" + Arrays.toString(ip) +
                ", counter=" + counter +
                '}';
    }

}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东羚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值