与 HTTPS 的实现类似:客户端生成对称加密密钥,用服务端公钥加密该对称密钥后传给服务端。这部分不是本文重点,不再详述。
schema.lua
配置参数:
limit_second:请求有效时间窗口(单位:秒)。请求参数中的 timestamp 与服务器时间相差超过该值即被拒绝,用于防止客户端重复(重放)请求。
inner_host:内网 host 列表。请求的 host 在列表中时,不做验签与解密。
local typedefs = require "kong.db.schema.typedefs"
-- Configuration record of the api-safety plugin.
local config_schema = {
  type = "record",
  fields = {
    -- Maximum allowed difference, in seconds, between the request's
    -- timestamp and the server clock (anti-replay window).
    { limit_second = { type = "number", required = true } },
    -- Hosts treated as internal: requests to them skip decryption and
    -- signature verification.
    { inner_host = { type = "array", elements = { type = "string" } } },
  },
}

return {
  name = "api-safety",
  fields = {
    { consumer = typedefs.no_consumer },
    { protocols = typedefs.protocols_http },
    { config = config_schema },
  },
}
access.lua
local safety = require "safety_lua" -- 加解密
--local version = require "version"
local multipart = require "multipart"
local cjson = require "cjson"
local pl_tablex = require "pl.tablex"
local math = require "math"
local kong = kong
local table_insert = table.insert
local get_uri_args = kong.request.get_query
local set_uri_args = kong.service.request.set_query
local clear_header = kong.service.request.clear_header
local get_header = kong.request.get_header
local set_header = kong.service.request.set_header
local clear_header = kong.service.request.clear_header
local get_headers = kong.request.get_headers
local set_headers = kong.service.request.set_headers
local set_method = kong.service.request.set_method
local get_method = kong.request.get_method
local get_raw_body = kong.request.get_raw_body
local set_raw_body = kong.service.request.set_raw_body
local encode_args = ngx.encode_args
local ngx_decode_args = ngx.decode_args
local type = type
local str_find = string.find
local pcall = pcall
local pairs = pairs
local error = error
local rawset = rawset
local pl_copy_table = pl_tablex.deepcopy
-- Module table returned at the end of access.lua.
local _M = {}

-- Log level alias.
local DEBUG = ngx.DEBUG

-- Lower-case header names.
local CONTENT_LENGTH = "content-length"
local CONTENT_TYPE = "content-type"
local HOST = "host"

-- Normalised body kinds produced by get_content_type().
local JSON = "json"
local MULTI = "multi_part"
local ENCODED = "form_encoded"

-- Shared read-only empty table.
local EMPTY = pl_tablex.readonly({})
-- Decode `body` as JSON.
-- Returns the decoded value, or nothing when `body` is nil/false or when
-- cjson.decode raises (invalid JSON); the decode error itself is discarded.
local function parse_json(body)
  if not body then
    return
  end
  local ok, decoded = pcall(cjson.decode, body)
  if not ok then
    return
  end
  return decoded
end
-- Decode an application/x-www-form-urlencoded string into a table.
-- A nil/false body yields an empty table; otherwise every value returned by
-- ngx.decode_args is passed through.
local function decode_args(body)
  if not body then
    return {}
  end
  return ngx_decode_args(body)
end
-- Map a Content-Type header value to JSON, MULTI or ENCODED.
-- Matching is a case-insensitive plain substring search, tried in that
-- order; returns nothing for nil input or an unrecognised media type.
local function get_content_type(content_type)
  if content_type == nil then
    return
  end
  local lowered = content_type:lower()
  if str_find(lowered, "application/json", nil, true) then
    return JSON
  end
  if str_find(lowered, "multipart/form-data", nil, true) then
    return MULTI
  end
  if str_find(lowered, "application/x-www-form-urlencoded", nil, true) then
    return ENCODED
  end
end
--- Check whether `val` equals (==) one of the values stored in `list`.
-- @param val value to look for
-- @param list table to search; nil is treated as empty
-- @treturn boolean true when found, false otherwise (always a boolean)
local function in_array(val, list)
  if not list then
    return false
  end
  for _, v in pairs(list) do
    if v == val then
      return true
    end
  end
  return false
end
-- Return a stateful iterator over `t` in ascending key order, yielding
-- (key, value). Keys are snapshotted and sorted with table.sort up front, so
-- they must be mutually comparable (e.g. all strings); values are read
-- lazily at each step.
local function table_sort_by_key(t)
  local keys = {}
  for key in pairs(t) do
    table.insert(keys, key)
  end
  table.sort(keys)
  local pos = 0
  return function()
    pos = pos + 1
    local key = keys[pos]
    return key, t[key]
  end
end
--- Return a shallow copy of `tbl` that omits the entry stored at `key`.
-- `tbl` itself is left unmodified. If `key` is absent, the copy holds every entry.
-- @param tbl source table
-- @param key key to drop
-- @treturn table new table
local function table_remove_by_key(tbl, key)
  local copy = {}
  for k, v in pairs(tbl) do
    if k ~= key then
      copy[k] = v
    end
  end
  return copy
end
--- Decrypt an encrypted `params` payload and verify its freshness and signature.
-- @param conf plugin config. `conf.limit_second` is the maximum allowed
--   difference, in seconds, between the payload's `timestamp` (milliseconds)
--   and the server clock.
-- @param params encrypted payload string taken from the request
-- @param secret value of the `Secret` request header; it is passed to
--   rsa.decrypt to obtain the key for safety.decrypt. It may be nil when the
--   header is missing.
-- @return status code (200 on success, negative on failure)
-- @return decrypted parameter table with `token` removed (empty on failure)
-- @return error message ("" on success)
local function decrypt_and_verify(conf, params, secret)
  -- A missing or repeated (table-valued) header must not reach rsa.decrypt.
  if type(secret) ~= "string" or secret == "" then
    return -401, {}, "Decrypt error"
  end
  -- NOTE(review): `rsa` is not required anywhere in this file; it is assumed
  -- to be provided elsewhere. Confirm, or add the proper require.
  local new_secret = rsa.decrypt(secret)
  if new_secret == nil then
    return -401, {}, "Decrypt error"
  end

  local decrypt_params = safety.decrypt(params, new_secret)
  if decrypt_params == nil then
    return -401, {}, "Decrypt error"
  end

  -- The decrypted payload must decode to a JSON object (a Lua table).
  local new_parameters = parse_json(decrypt_params)
  if type(new_parameters) ~= "table" then
    return -402, {}, "Incorrect data format"
  end

  -- Anti-replay check. `timestamp` is in milliseconds. A missing or
  -- non-numeric value is treated as 0, which is rejected unless limit_second
  -- is enormous.
  local timestamp = tonumber(new_parameters["timestamp"]) or 0
  local between_second = ngx.now() - timestamp / 1000
  if math.abs(between_second) > conf.limit_second then
    return -403, {}, "Request time is invalid"
  end

  local sign = new_parameters["token"]
  if sign == nil then
    return -404, {}, "Sign is empty"
  end

  -- The signature covers the values of every field except `token`,
  -- concatenated in ascending key order.
  new_parameters = table_remove_by_key(new_parameters, "token")
  local pieces = {}
  for _, v in table_sort_by_key(new_parameters) do
    pieces[#pieces + 1] = v
  end
  local check_sign = safety.signature(table.concat(pieces))
  if sign ~= check_sign then
    return -405, {}, "Sign is error"
  end
  return 200, new_parameters, ""
end
--- Decrypt and verify the `params` query argument of a GET request.
-- On success, replace the upstream query string with the decrypted
-- parameters. On any failure, terminate the request:
--   * HTTP 480 (code -480) for an expired timestamp (-403);
--   * otherwise HTTP 200 with the error code in the body.
-- @param conf plugin config
-- @param header request headers table; the `Secret` header is read
local function transform_querystrings(conf, header)
  local query_string = get_uri_args()
  local params = query_string["params"]
  local err_code, parameters, err_msg
  -- A repeated argument yields a table and a bare `?params` yields `true`.
  -- Only a non-empty string is a valid encrypted payload.
  if type(params) ~= "string" or params == "" then
    err_code, err_msg = -400, "Params is required"
  else
    err_code, parameters, err_msg =
      decrypt_and_verify(conf, params, header and header["Secret"])
  end

  -- The caller ignores return values, so every failure (including a missing
  -- `params`) must terminate here. Returning would forward an unverified
  -- request upstream.
  if err_code ~= 200 then
    if err_code == -403 then
      return kong.response.exit(480, { code = -480, data = {}, msg = err_msg })
    end
    return kong.response.exit(200, { code = err_code, data = {}, msg = err_msg })
  end

  set_uri_args(parameters)
end
--- Decrypt and verify the `params` field of a JSON request body.
-- @param conf plugin config
-- @param body raw request body (string or nil)
-- @param content_length byte length of `body`; recomputed when nil
-- @param header request headers table; the `Secret` header is read
-- @return status (200 or a negative error code)
-- @return new body: a JSON string on success, an empty table on failure
-- @return error message ("" on success)
local function transform_json_body(conf, body, content_length, header)
  content_length = content_length or (body and #body) or 0
  if content_length == 0 then
    return -400, {}, "Body is empty"
  end

  -- Invalid JSON, or a body that is not an object, carries no `params` field.
  local table_params = parse_json(body)
  local params = type(table_params) == "table" and table_params["params"] or nil
  if type(params) ~= "string" or params == "" then
    return -400, {}, "Params is required"
  end

  local err_code, parameters, err_msg =
    decrypt_and_verify(conf, params, header and header["Secret"])
  if err_code ~= 200 then
    return err_code, parameters, err_msg
  end
  -- The decrypted parameters become the new upstream JSON body.
  return 200, cjson.encode(parameters), ""
end
--- Decrypt and verify the `params` field of a form-urlencoded request body.
-- @param conf plugin config
-- @param body raw request body (string or nil)
-- @param content_length byte length of `body`
-- @param header request headers table; the `Secret` header is read
-- @return status (200 or a negative error code)
-- @return new body: an urlencoded string on success, an empty table on failure
-- @return error message ("" on success)
local function transform_url_encoded_body(conf, body, content_length, header)
  if content_length == 0 then
    return -400, {}, "Body is empty"
  end

  local table_params = decode_args(body)
  local params = table_params["params"]
  -- Repeated keys decode to a table and bare keys decode to `true`; only a
  -- non-empty string is a valid payload.
  if type(params) ~= "string" or params == "" then
    return -400, {}, "Params is required"
  end

  -- The Secret header must be forwarded; without it decryption cannot succeed.
  local err_code, parameters, err_msg =
    decrypt_and_verify(conf, params, header and header["Secret"])
  if err_code ~= 200 then
    return err_code, parameters, err_msg
  end
  return 200, encode_args(parameters), ""
end
--- Decrypt and verify the `params` part of a multipart/form-data request body.
-- The `params` part is removed, and each decrypted parameter is added back
-- as a simple part.
-- @param conf plugin config
-- @param body raw request body (string or nil)
-- @param content_length byte length of `body`
-- @param content_type_value raw Content-Type header (carries the boundary)
-- @param header optional request headers table for reading `Secret`;
--   fetched with get_headers() when omitted
-- @return status (200 or a negative error code)
-- @return new body: a multipart string on success, an empty table on failure
-- @return error message ("" on success)
local function transform_multipart_body(conf, body, content_length, content_type_value, header)
  if content_length == 0 then
    return -400, {}, "Body is empty"
  end
  header = header or get_headers()

  local mul_parameters = multipart(body or "", content_type_value)
  -- get() can return nil when the part is absent.
  local part = mul_parameters:get("params")
  local params = part and part.value
  if type(params) ~= "string" or params == "" then
    return -400, {}, "Params is required"
  end

  local err_code, parameters, err_msg = decrypt_and_verify(conf, params, header["Secret"])
  if err_code ~= 200 then
    return err_code, parameters, err_msg
  end

  mul_parameters:delete("params")
  for key, value in pairs(parameters) do
    mul_parameters:set_simple(key, value)
  end
  return 200, mul_parameters:tostring(), ""
end
--- Decrypt and verify a POST body according to its Content-Type.
-- On success, replace the upstream body with the decrypted parameters and
-- update Content-Length. On failure, terminate the request:
--   * HTTP 480 (code -480) for an expired timestamp (-403);
--   * otherwise HTTP 200 with the error code in the JSON body.
-- @param conf plugin config
-- @param header request headers table; the `Secret` header is read
local function transform_body(conf, header)
  local content_type_value = get_header(CONTENT_TYPE)
  local content_type = get_content_type(content_type_value)
  if content_type == nil then
    -- NOTE(review): a body with a missing or unsupported Content-Type is
    -- forwarded without decryption or signature verification. Confirm this
    -- is intended.
    return
  end

  local body = get_raw_body()
  local content_length = (body and #body) or 0
  local error_code, new_body, msg
  if content_type == ENCODED then
    error_code, new_body, msg = transform_url_encoded_body(conf, body, content_length, header)
  elseif content_type == MULTI then
    -- Headers are passed so the multipart handler can read `Secret`.
    error_code, new_body, msg =
      transform_multipart_body(conf, body, content_length, content_type_value, header)
  else -- JSON: get_content_type returns only ENCODED, MULTI, JSON or nil
    error_code, new_body, msg = transform_json_body(conf, body, content_length, header)
  end

  if error_code == 200 then
    set_raw_body(new_body)
    set_header(CONTENT_LENGTH, #new_body)
    return
  end
  if error_code == -403 then
    return kong.response.exit(480, { code = -480, data = {}, msg = msg })
  end
  return kong.response.exit(200, { code = error_code, data = {}, msg = msg })
end
--- Access-phase entry point.
-- For a host listed in `conf.inner_host`, set `Need-Sign: 0` and leave the
-- request untouched. Otherwise, decrypt and verify POST bodies and GET query
-- strings. Any other method passes through unchanged.
-- NOTE(review): nothing in this file sets ngx.ctx.use_flag, which the
-- header/body filters check before encrypting the response. Confirm where it
-- is set.
-- @param conf plugin config
function _M.execute(conf)
  -- Remove any client-supplied Need-Sign header so it cannot be used to
  -- bypass signature verification.
  clear_header("Need-Sign")

  local host = kong.request.get_host()
  local headers = get_headers()

  if in_array(host, conf.inner_host) then
    -- Internal host: no signature verification is required.
    set_header("Need-Sign", 0)
    return
  end

  local method = get_method()
  if method == "POST" then
    transform_body(conf, headers)
  elseif method == "GET" then
    transform_querystrings(conf, headers)
  end
end

return _M
body_filter.lua
local safety = require "safety_lua"
local cjson = require "cjson"
local concat = table.concat
local str_find = string.find
-- Lower-case header names.
local CONTENT_LENGTH = "content-length"
local CONTENT_TYPE = "content-type"

-- Body kinds returned by get_content_type().
local JSON = "json"
local MULTI = "multi_part"
local ENCODED = "form_encoded"

-- Module table returned at the end of body_filter.lua.
local _M = {}
-- Map a Content-Type header value to JSON, MULTI or ENCODED using a
-- case-insensitive plain substring search. Candidates are tried in list
-- order; returns nothing for nil input or an unrecognised media type.
local function get_content_type(content_type)
  if content_type == nil then
    return
  end
  local lowered = content_type:lower()
  local candidates = {
    { "application/json", JSON },
    { "multipart/form-data", MULTI },
    { "application/x-www-form-urlencoded", ENCODED },
  }
  for _, candidate in ipairs(candidates) do
    if str_find(lowered, candidate[1], nil, true) then
      return candidate[2]
    end
  end
end
-- Decode `body` as JSON. Returns the decoded value, or nothing when `body`
-- is nil/false or when cjson.decode raises (invalid JSON).
local function parse_json(body)
  if not body then
    return
  end
  local ok, result = pcall(cjson.decode, body)
  if ok then
    return result
  end
end
--- Body-filter worker.
-- Buffers every upstream chunk in ngx.ctx and emits nothing until eof. It
-- then outputs the whole body: encrypted with safety.encrypt when
-- ngx.ctx.use_flag is true and the body parses as JSON, unchanged otherwise.
local function transform_body()
  local ctx = ngx.ctx
  local chunk, eof = ngx.arg[1], ngx.arg[2]
  ctx.rt_body_chunks = ctx.rt_body_chunks or {}
  ctx.rt_body_chunk_number = ctx.rt_body_chunk_number or 1

  -- Buffer the current chunk. The call that carries eof may still contain
  -- data, so it has to be appended before the body is assembled.
  if chunk and chunk ~= "" then
    ctx.rt_body_chunks[ctx.rt_body_chunk_number] = chunk
    ctx.rt_body_chunk_number = ctx.rt_body_chunk_number + 1
  end

  if not eof then
    -- Hold output back until the full body is available.
    ngx.arg[1] = nil
    return
  end

  local body = concat(ctx.rt_body_chunks)
  if ctx.use_flag == true and parse_json(body) ~= nil then
    -- NOTE(review): the encryption key is read from ngx.ctx.secret, the
    -- per-request context. Nothing visible here sets it; confirm that the
    -- access phase stores the decrypted symmetric key there.
    ngx.arg[1] = safety.encrypt(body, ctx.secret)
  else
    -- Either not flagged for encryption, or the upstream returned non-JSON
    -- data: pass the body through untouched.
    ngx.arg[1] = body
  end
end
-- Body-filter phase entry point; delegates to transform_body().
_M.execute = function()
  transform_body()
end

return _M
header_filter.lua
注意:响应体加密后长度会改变,因此务必删除 Content-Length 响应头!
local kong = kong
local math = require "math"
local _M = {}

--- Set the plugin's response headers.
-- When ngx.ctx.use_flag is true, the body filter may replace the body with
-- ciphertext. In that case Content-Type becomes text/plain, Encrypt-Flag is
-- 1, and Content-Length is removed because the encrypted length differs.
-- Otherwise Encrypt-Flag is 0. System-Time (server clock in milliseconds,
-- rounded up) is always added so clients can correct their timestamps.
local function transform_header()
  local encrypted = ngx.ctx.use_flag == true

  -- Refresh nginx's cached clock before reading it.
  ngx.update_time()
  local system_time = math.ceil(ngx.now() * 1000)

  if encrypted then
    kong.response.set_header("Content-Type", "text/plain")
    kong.response.set_header("Encrypt-Flag", 1)
    kong.response.clear_header("Content-Length")
  else
    kong.response.set_header("Encrypt-Flag", 0)
  end
  kong.response.set_header("System-Time", system_time)
end

-- Header-filter phase entry point.
function _M.execute()
  transform_header()
end

return _M
handler.lua
local access = require "kong.plugins.api-safety.access"
local body_filter = require "kong.plugins.api-safety.body_filter"
local header_filter = require "kong.plugins.api-safety.header_filter"
-- Kong handler for the api-safety plugin. Each phase method delegates to
-- the matching module required above.
local SafetyHandler = {
  -- Plugins with a higher PRIORITY run earlier.
  PRIORITY = 800,
  VERSION = "1.0.0",
}

-- Request phase.
function SafetyHandler:access(conf)
  access.execute(conf)
end

-- Response header phase.
function SafetyHandler:header_filter()
  header_filter.execute()
end

-- Response body phase.
function SafetyHandler:body_filter()
  body_filter.execute()
end

return SafetyHandler